diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py index 7b64ea44cfa53459e475e622fa5afe8071f54d8b..e5df485eda1d8492a06c7860db21773e43951819 100644 --- a/benchmarks/attention/benchmark_attention.py +++ b/benchmarks/attention/benchmark_attention.py @@ -14,7 +14,7 @@ from tests.pytorch.fused_attn.test_fused_attn import ( _is_flash_attention_supported, _is_fused_attention_supported, _is_unfused_attention_supported, - _run_dot_product_attention + _run_dot_product_attention, ) pd.set_option("display.precision", 4) @@ -28,7 +28,7 @@ ckpt_attn = False # workspace optimization path for cuDNN attention workspace_opt = True # QKV memory layout -qkv_layout = 'bshd_bshd_bshd' +qkv_layout = "bshd_bshd_bshd" # sliding window attention swa = False # padding between sequences for qkv_format=thd @@ -38,16 +38,17 @@ is_training = True model_configs = { # test: b, h, hg, d, sq, skv, p, mask, bias - "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq - "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask - "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias - "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA + "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq + "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask + "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias + "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA } + def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported): config = model_configs[model] if dtype == torch.bfloat16: - tols = dict(atol=2.5e-2, rtol=2.5e-2) + tols = dict(atol=2.5e-2, rtol=2.5e-2) else: tols = dict(atol=5e-3, rtol=5e-3) @@ -57,17 +58,31 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp for i in range(warmup_iters): if fused_attn_supported: fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( - dtype, config, "FusedAttention", - ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training, + dtype, + config, + "FusedAttention", + ckpt_attn, + qkv_layout, + workspace_opt, + swa, + pad_between_seqs, + is_training, ) if flash_attn_supported: flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( - dtype, config, "FlashAttention", - ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training, + dtype, + config, + "FlashAttention", + ckpt_attn, + qkv_layout, + workspace_opt, + swa, + pad_between_seqs, + is_training, ) if fused_attn_supported and flash_attn_supported: torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) - for i,_ in enumerate(flash_attn_bwd): + for i, _ in enumerate(flash_attn_bwd): torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols) torch.cuda.cudart().cudaProfilerStart() @@ -76,8 +91,15 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp if fused_attn_supported: for i in range(num_iters): fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( - dtype, config, "FusedAttention", - ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training, + dtype, + config, + "FusedAttention", + ckpt_attn, + qkv_layout, + workspace_opt, + swa, + pad_between_seqs, + is_training, ) torch.cuda.synchronize() fused_attn_time = time.time() - 
fused_attn_start if fused_attn_supported else 0 @@ -87,81 +109,113 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp if flash_attn_supported: for i in range(num_iters): flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( - dtype, config, "FlashAttention", - ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training, + dtype, + config, + "FlashAttention", + ckpt_attn, + qkv_layout, + workspace_opt, + swa, + pad_between_seqs, + is_training, ) torch.cuda.synchronize() flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0 - df = pd.read_csv('times.csv') - df = pd.concat([ - df, - pd.DataFrame( - [[fused_attn_time*1e3/num_iters, 0, 0, 0, - flash_attn_time*1e3/num_iters, 0, 0, 0, 0]], columns=df.columns)], - ignore_index=True - ) - df.to_csv('times.csv',index=False) + df = pd.read_csv("times.csv") + df = pd.concat( + [ + df, + pd.DataFrame( + [ + [ + fused_attn_time * 1e3 / num_iters, + 0, + 0, + 0, + flash_attn_time * 1e3 / num_iters, + 0, + 0, + 0, + 0, + ] + ], + columns=df.columns, + ), + ], + ignore_index=True, + ) + df.to_csv("times.csv", index=False) torch.cuda.cudart().cudaProfilerStop() + def parse_results(per_cudnn, per_flash, model): - filename = f'prof_{model}_cuda_gpu_trace.csv' - df = pd.read_csv(os.path.join('./',filename)) - df_times = pd.read_csv('times.csv') - row = len(df_times.index)-1 - + filename = f"prof_{model}_cuda_gpu_trace.csv" + df = pd.read_csv(os.path.join("./", filename)) + df_times = pd.read_csv("times.csv") + row = len(df_times.index) - 1 + if per_cudnn > 0: - t_cudnn_all = df[df['Name'].str.contains('cudnn')]['Duration (ns)'].to_numpy() + t_cudnn_all = df[df["Name"].str.contains("cudnn")]["Duration (ns)"].to_numpy() t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn) t_cudnn_avg = np.average(t_cudnn_all, axis=0) - df_times.loc[row, 'FusedAttention Kernels (fwd)'] = t_cudnn_avg[0]/1e6 - df_times.loc[row, 'FusedAttention Kernels (bwd)'] = t_cudnn_avg[1:4].sum()/1e6 - df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)'] = t_cudnn_avg.sum()/1e6 + df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6 + df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6 + df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6 if per_flash > 0: - t_flash_all = df[df['Name'].str.contains('void flash')]['Duration (ns)'].to_numpy() + t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy() t_flash_all = t_flash_all.reshape(-1, per_flash) t_flash_avg = np.average(t_flash_all, axis=0) - df_times.loc[row, 'FlashAttention Kernels (fwd)'] = t_flash_avg[0]/1e6 - df_times.loc[row, 'FlashAttention Kernels (bwd)'] = t_flash_avg[1:4].sum()/1e6 - df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] = t_flash_avg.sum()/1e6 + df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6 + df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6 + df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6 if per_cudnn > 0 and per_flash > 0: - df_times.loc[row, 'Fused vs Flash Kernels Speedup (fwd+bwd)'] = \ - df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] / \ - df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)'] - df_times.to_csv('times.csv',index=False) + df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = ( + df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] + / df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] + ) + 
df_times.to_csv("times.csv", index=False) + def main(): times = pd.DataFrame( - columns=[ - 'FusedAttention Module', - 'FusedAttention Kernels (fwd)', - 'FusedAttention Kernels (bwd)', - 'FusedAttention Kernels (fwd+bwd)', - 'FlashAttention Module', - 'FlashAttention Kernels (fwd)', - 'FlashAttention Kernels (bwd)', - 'FlashAttention Kernels (fwd+bwd)', - 'Fused vs Flash Kernels Speedup (fwd+bwd)', - ]) - times.to_csv('times.csv',index=False) + columns=[ + "FusedAttention Module", + "FusedAttention Kernels (fwd)", + "FusedAttention Kernels (bwd)", + "FusedAttention Kernels (fwd+bwd)", + "FlashAttention Module", + "FlashAttention Kernels (fwd)", + "FlashAttention Kernels (bwd)", + "FlashAttention Kernels (fwd+bwd)", + "Fused vs Flash Kernels Speedup (fwd+bwd)", + ] + ) + times.to_csv("times.csv", index=False) device_id = torch.cuda.current_device() device_properties = torch.cuda.get_device_properties(device_id) - print(f"Device {device_id}: " + print( + f"Device {device_id}: " f"{device_properties.name} GPU, " f"sm{device_properties.major}{device_properties.minor} compute capability, " - f"{device_properties.total_memory/1024**3:.1f}GB memory") + f"{device_properties.total_memory/1024**3:.1f}GB memory" + ) for model in model_configs.keys(): config = model_configs[model] fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( - config, dtype, qkv_layout=qkv_layout, + config, + dtype, + qkv_layout=qkv_layout, ) fused_attn_supported = fused_attn_supported and not swa flash_attn_supported = _is_flash_attention_supported(config) - print(f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}' - f'{" and flash-attention" if flash_attn_supported else ""}...') + print( + f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}' + f'{" and flash-attention" if flash_attn_supported else ""}...' 
+ ) prof_cmd = [ "nsys", @@ -175,8 +229,8 @@ def main(): f""" "import benchmark_attention;""", f"""benchmark_attention.benchmark_dot_product_attention(""" f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """, - ] - prof_cmd = ' '.join(prof_cmd) + ] + prof_cmd = " ".join(prof_cmd) subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True) stats_cmd = [ "nsys", @@ -190,17 +244,17 @@ def main(): "--force-export=true", f"--output=prof_{model}", f"prof_{model}.nsys-rep", - ] + ] if fused_attn_supported: num_kernels_cudnn = 4 - if config.attn_bias_type == 'post_scale_bias': - num_kernels_cudnn = num_kernels_cudnn+1 + if config.attn_bias_type == "post_scale_bias": + num_kernels_cudnn = num_kernels_cudnn + 1 if config.num_heads != config.num_gqa_groups: - num_kernels_cudnn = num_kernels_cudnn+2 + num_kernels_cudnn = num_kernels_cudnn + 2 else: num_kernels_cudnn = 0 num_kernels_flash = 4 if flash_attn_supported else 0 - stats_cmd = ' '.join(stats_cmd) + stats_cmd = " ".join(stats_cmd) subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True) parse_cmd = [ "python", @@ -208,18 +262,23 @@ def main(): f""" "import benchmark_attention;""", f"""benchmark_attention.parse_results(""" f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """, - ] - parse_cmd = ' '.join(parse_cmd) + ] + parse_cmd = " ".join(parse_cmd) subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True) - df_times = pd.read_csv('times.csv') + df_times = pd.read_csv("times.csv") df_times.index = list(model_configs.keys()) - a=df_times[['FusedAttention Kernels (fwd+bwd)', - 'FlashAttention Kernels (fwd+bwd)', - 'Fused vs Flash Kernels Speedup (fwd+bwd)']] - a.columns = ['cuDNN fwd+bwd (ms)', 'flash-attn fwd+bwd (ms)', 'cuDNN vs flash speedup'] + a = df_times[ + [ + "FusedAttention Kernels (fwd+bwd)", + "FlashAttention Kernels (fwd+bwd)", + "Fused vs Flash Kernels Speedup (fwd+bwd)", + ] + ] + a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"] print() print(a) + if __name__ == "__main__": main() diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 170a8da843cc833f9848d2e8c5f3d1ef51d65582..73414864cb900cc60b77acadab5971b0dee54ddd 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -64,6 +64,7 @@ class CMakeExtension(setuptools.Extension): configure_command.append("-GNinja") import pybind11 + pybind11_dir = Path(pybind11.__file__).resolve().parent pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11" configure_command.append(f"-Dpybind11_DIR={pybind11_dir}") @@ -130,6 +131,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]): else: # Only during release sdist build. 
import transformer_engine + search_paths = list(Path(transformer_engine.__path__[0]).iterdir()) del transformer_engine @@ -142,8 +144,9 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]): # Figure out stub file path module_name = paddle_ext.name - assert module_name.endswith("_pd_"), \ - "Expected Paddle extension module to end with '_pd_'" + assert module_name.endswith( + "_pd_" + ), "Expected Paddle extension module to end with '_pd_'" stub_name = module_name[:-4] # remove '_pd_' stub_path = os.path.join(self.build_lib, "transformer_engine", stub_name + ".py") Path(stub_path).parent.mkdir(exist_ok=True, parents=True) @@ -158,6 +161,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]): # Write stub file print(f"Writing Paddle stub for {lib_name} into file {stub_path}") from paddle.utils.cpp_extension.extension_utils import custom_write_stub + custom_write_stub(lib_name, stub_path) # Ensure that binaries are not in global package space. @@ -182,13 +186,14 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]): # extra_compile_args is a dict. for ext in self.extensions: if isinstance(ext.extra_compile_args, dict): - for target in ['cxx', 'nvcc']: + for target in ["cxx", "nvcc"]: if target not in ext.extra_compile_args.keys(): ext.extra_compile_args[target] = [] # Define new _compile method that redirects to NVCC for .cu and .cuh files. original_compile_fn = self.compiler._compile - self.compiler.src_extensions += ['.cu', '.cuh'] + self.compiler.src_extensions += [".cu", ".cuh"] + def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None: # Copy before we make any modifications. cflags = copy.deepcopy(extra_postargs) @@ -197,31 +202,31 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]): _, nvcc_bin = cuda_path() original_compiler = self.compiler.compiler_so - if os.path.splitext(src)[1] in ['.cu', '.cuh']: - self.compiler.set_executable('compiler_so', str(nvcc_bin)) + if os.path.splitext(src)[1] in [".cu", ".cuh"]: + self.compiler.set_executable("compiler_so", str(nvcc_bin)) if isinstance(cflags, dict): - cflags = cflags['nvcc'] + cflags = cflags["nvcc"] # Add -fPIC if not already specified - if not any('-fPIC' in flag for flag in cflags): - cflags.extend(['--compiler-options', "'-fPIC'"]) + if not any("-fPIC" in flag for flag in cflags): + cflags.extend(["--compiler-options", "'-fPIC'"]) # Forward unknown options - if not any('--forward-unknown-opts' in flag for flag in cflags): - cflags.append('--forward-unknown-opts') + if not any("--forward-unknown-opts" in flag for flag in cflags): + cflags.append("--forward-unknown-opts") elif isinstance(cflags, dict): - cflags = cflags['cxx'] + cflags = cflags["cxx"] # Append -std=c++17 if not already in flags - if not any(flag.startswith('-std=') for flag in cflags): - cflags.append('-std=c++17') + if not any(flag.startswith("-std=") for flag in cflags): + cflags.append("-std=c++17") return original_compile_fn(obj, src, ext, cc_args, cflags, pp_opts) finally: # Put the original compiler back in place. 
- self.compiler.set_executable('compiler_so', original_compiler) + self.compiler.set_executable("compiler_so", original_compiler) self.compiler._compile = _compile_fn diff --git a/build_tools/jax.py b/build_tools/jax.py index 21248640cc2288c6e6d13d0e74a8c9948e3d8a43..496bf056e8270676d69f4ef706d5e9a37bcfaa4b 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -36,8 +36,8 @@ def setup_jax_extension( ] # Compile flags - cxx_flags = [ "-O3" ] - nvcc_flags = [ "-O3" ] + cxx_flags = ["-O3"] + nvcc_flags = ["-O3"] # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension @@ -47,9 +47,9 @@ def setup_jax_extension( def _add_cflags(self, flags: List[str]) -> None: if isinstance(self.extra_compile_args, dict): - cxx_flags = self.extra_compile_args.pop('cxx', []) + cxx_flags = self.extra_compile_args.pop("cxx", []) cxx_flags += flags - self.extra_compile_args['cxx'] = cxx_flags + self.extra_compile_args["cxx"] = cxx_flags else: self.extra_compile_args[:0] = flags @@ -57,8 +57,5 @@ def setup_jax_extension( "transformer_engine_jax", sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], - extra_compile_args={ - "cxx": cxx_flags, - "nvcc": nvcc_flags - }, + extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags}, ) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 655b7dbf26896df608ec900c4165d3fcb7762a18..0b44cfb372f17058bbddbf85adbd0d32b1423471 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -76,11 +76,12 @@ def setup_pytorch_extension( # Libraries -- PyTorch CUDAExtension links to libcudart.so but not to libcuda.so cuda_home, _ = cuda_path() - library_dirs = [ cuda_home / "compat" / "lib" ] - libraries = [ "cuda" ] + library_dirs = [cuda_home / "compat" / "lib"] + libraries = ["cuda"] if os.getenv("UB_MPI_BOOTSTRAP"): - assert os.getenv("MPI_HOME") is not None, \ - "MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1" + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1" mpi_home = Path(os.getenv("MPI_HOME")) include_dirs.append(mpi_home / "include") cxx_flags.append("-DUB_MPI_BOOTSTRAP") @@ -95,12 +96,12 @@ def setup_pytorch_extension( return CUDAExtension( name="transformer_engine_torch", - sources=[ str(src) for src in sources ], - include_dirs=[ str(inc) for inc in include_dirs ], + sources=[str(src) for src in sources], + include_dirs=[str(inc) for inc in include_dirs], extra_compile_args={ "cxx": cxx_flags, "nvcc": nvcc_flags, }, - libraries=[ str(lib) for lib in libraries ], - library_dirs=[ str(lib_dir) for lib_dir in library_dirs ], + libraries=[str(lib) for lib in libraries], + library_dirs=[str(lib_dir) for lib_dir in library_dirs], ) diff --git a/build_tools/te_version.py b/build_tools/te_version.py index 9ebc0ccdc470c5ac4d999af8dd42668bddfb48f9..b40fb260146d35bf0c1728ac38ac826b51fa8969 100644 --- a/build_tools/te_version.py +++ b/build_tools/te_version.py @@ -18,11 +18,12 @@ def te_version() -> str: root_path = Path(__file__).resolve().parent with open(root_path / "VERSION.txt", "r") as f: version = f.readline().strip() - if (not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")) - and not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))): + if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")) and not bool( + int(os.getenv("NVTE_RELEASE_BUILD", "0")) + ): try: output = subprocess.run( - ["git", "rev-parse" , "--short", "HEAD"], + ["git", "rev-parse", "--short", "HEAD"], capture_output=True, cwd=root_path, 
check=True, diff --git a/build_tools/utils.py b/build_tools/utils.py index b601ec137e8cf29e01636f8cda95d57f4a8dac46..036cb1eac5e28e089add81bcac531f4288828e26 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -174,7 +174,7 @@ def cuda_version() -> Tuple[int, ...]: universal_newlines=True, ) match = re.search(r"release\s*([\d.]+)", output.stdout) - version = match.group(1).split('.') + version = match.group(1).split(".") return tuple(int(v) for v in version) @@ -224,9 +224,7 @@ def get_frameworks() -> List[str]: _frameworks = [framework.lower() for framework in _frameworks] for framework in _frameworks: if framework not in supported_frameworks: - raise ValueError( - f"Transformer Engine does not support framework={framework}" - ) + raise ValueError(f"Transformer Engine does not support framework={framework}") return _frameworks @@ -242,8 +240,8 @@ def package_files(directory): def copy_common_headers(te_src, dst): headers = te_src / "common" - for file_path in glob.glob(os.path.join(str(headers), "**", '*.h'), recursive=True): - new_path = os.path.join(dst, file_path[len(str(te_src)) + 1:]) + for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True): + new_path = os.path.join(dst, file_path[len(str(te_src)) + 1 :]) Path(new_path).parent.mkdir(exist_ok=True, parents=True) shutil.copy(file_path, new_path) @@ -251,9 +249,10 @@ def copy_common_headers(te_src, dst): def install_and_import(package): """Install a package via pip (if not already installed) and import into globals.""" import importlib + try: importlib.import_module(package) except ImportError: - subprocess.check_call([sys.executable, '-m', 'pip', 'install', package]) + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) finally: globals()[package] = importlib.import_module(package) diff --git a/docs/conf.py b/docs/conf.py index 77b59eb8e0851135c8387c0193d37bcacbd0af43..695546a9bad0b2cb97457f18beb02e6f55c10d03 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,21 +28,26 @@ if current_year == release_year: else: copyright_year = str(release_year) + "-" + str(current_year) -project = u'Transformer Engine' -copyright = u'{}, NVIDIA CORPORATION & AFFILIATES. All rights reserved.'.format(copyright_year) -author = u'NVIDIA CORPORATION & AFFILIATES' +project = "Transformer Engine" +copyright = "{}, NVIDIA CORPORATION & AFFILIATES. All rights reserved.".format(copyright_year) +author = "NVIDIA CORPORATION & AFFILIATES" git_sha = os.getenv("GIT_SHA") if not git_sha: try: - git_sha = subprocess.check_output(["git", "log", "--pretty=format:'%h'", "-n1"]).decode('ascii').replace("'","").strip() + git_sha = ( + subprocess.check_output(["git", "log", "--pretty=format:'%h'", "-n1"]) + .decode("ascii") + .replace("'", "") + .strip() + ) except: - git_sha = u'0000000' + git_sha = "0000000" git_sha = git_sha[:7] if len(git_sha) > 7 else git_sha -version = str(te_version + u"-" + git_sha) +version = str(te_version + "-" + git_sha) release = te_version # hack: version is used for html creation, so put the version picker @@ -51,58 +56,60 @@ option_on = " selected" option_off = "" release_opt = option_on option_nr = 0 -version = version + """
+version = (
+    version
+    + """
Version select: """.format(option_nr, release_opt) +""".format( + option_nr, release_opt + ) +) # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', - 'sphinx.ext.ifconfig', - 'nbsphinx', - 'breathe', - 'autoapi.extension', + "sphinx.ext.autodoc", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx.ext.ifconfig", + "nbsphinx", + "breathe", + "autoapi.extension", ] -templates_path = ['_templates'] -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] -source_suffix = '.rst' +source_suffix = ".rst" -master_doc = 'index' - -pygments_style = 'sphinx' +master_doc = "index" +pygments_style = "sphinx" # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] -html_static_path = ['_static'] +html_static_path = ["_static"] html_show_sphinx = False html_css_files = [ - 'css/nvidia_font.css', - 'css/nvidia_footer.css', + "css/nvidia_font.css", + "css/nvidia_footer.css", ] -html_theme_options = { - 'display_version': True, - 'collapse_navigation': False, - 'logo_only': False -} +html_theme_options = {"display_version": True, "collapse_navigation": False, "logo_only": False} -napoleon_custom_sections = [('Parallelism parameters', 'params_style'), - ('Optimization parameters', 'params_style'), - ('Values', 'params_style')] +napoleon_custom_sections = [ + ("Parallelism parameters", "params_style"), + ("Optimization parameters", "params_style"), + ("Values", "params_style"), +] breathe_projects = {"TransformerEngine": os.path.abspath("doxygen/xml/")} breathe_default_project = "TransformerEngine" diff --git a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py index 33852286d84efc91b007f7ded161d19971b32ab8..cd8ab85ba245ef5173c84dd4ac5235b152d4a867 100644 --- a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py +++ b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py @@ -4,8 +4,8 @@ import os import torch -from typing import Tuple -from tests.pytorch.fused_attn.test_fused_attn import ModelConfig +from typing import Tuple +from tests.pytorch.fused_attn.test_fused_attn import ModelConfig from transformer_engine.pytorch.distributed import _set_cuda_rng_state from transformer_engine.pytorch.attention import DotProductAttention @@ -18,87 +18,105 @@ _cuda_rng_state = torch.cuda.get_rng_state() _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) + def reset_rng_states() -> None: """Revert back to initial RNG state""" torch.set_rng_state(_cpu_rng_state) _set_cuda_rng_state(_cuda_rng_state) + def _run_dot_product_attention( - dtype: torch.dtype, - config: ModelConfig, - qkv_layout: str, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + dtype: torch.dtype, + config: ModelConfig, + qkv_layout: str, +) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Run DotProductAttention module with one forward pass and one backward pass""" reset_rng_states() - seqlens_q = torch.full([config.batch_size], config.max_seqlen_q, - dtype=torch.int32, device="cuda") - 
seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv, - dtype=torch.int32, device="cuda") - inp = torch.randn([config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim], - dtype=dtype, device="cuda") - q = inp[:,:,0,:,:] - k = inp[:,:,1,:,:] - v = inp[:,:,2,:,:] + seqlens_q = torch.full( + [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.full( + [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" + ) + inp = torch.randn( + [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim], + dtype=dtype, + device="cuda", + ) + q = inp[:, :, 0, :, :] + k = inp[:, :, 1, :, :] + v = inp[:, :, 2, :, :] q.requires_grad = True k.requires_grad = True v.requires_grad = True - out_grad = torch.randn([config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim], - dtype=dtype, device="cuda") + out_grad = torch.randn( + [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim], + dtype=dtype, + device="cuda", + ) # Create attention mask / bias attention_mask = None bias = None if config.attn_mask_type == "arbitrary": - attention_mask = torch.randint(-10,10, - [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv]).to( - dtype=torch.bool, device="cuda") + attention_mask = torch.randint( + -10, + 10, + [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv], + ).to(dtype=torch.bool, device="cuda") if config.attn_bias_type == "post_scale_bias": # convert mask to bias - attention_mask = torch.randint(-10,10, - [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv]).to( - dtype=torch.bool, device="cuda") + attention_mask = torch.randint( + -10, + 10, + [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv], + ).to(dtype=torch.bool, device="cuda") bias = attention_mask.clone() - neginf = -2**50 if dtype == torch.bfloat16 else -2**15 - bias = torch.where(bias==0, 0, neginf).to(dtype=dtype, device='cuda') + neginf = -(2**50) if dtype == torch.bfloat16 else -(2**15) + bias = torch.where(bias == 0, 0, neginf).to(dtype=dtype, device="cuda") bias.requires_grad = False attention_mask = None - block = ( - DotProductAttention( - config.num_heads, - config.head_dim, - num_gqa_groups=config.num_gqa_groups, - qkv_format='bshd', - attention_dropout=config.dropout_p, - sequence_parallel=False, - tp_size=1, - get_rng_state_tracker=None, - tp_group=None, - layer_number=1, - ).to(dtype=dtype, device="cuda") - ) + block = DotProductAttention( + config.num_heads, + config.head_dim, + num_gqa_groups=config.num_gqa_groups, + qkv_format="bshd", + attention_dropout=config.dropout_p, + sequence_parallel=False, + tp_size=1, + get_rng_state_tracker=None, + tp_group=None, + layer_number=1, + ).to(dtype=dtype, device="cuda") # Run a forward and backward pass out = None if config.attn_mask_type == "arbitrary": - out = block(q, k, v, - attention_mask=attention_mask, # attention_mask - qkv_format='bshd', - attn_mask_type=config.attn_mask_type, # 'arbitrary' - core_attention_bias_type=config.attn_bias_type, # 'no_bias' - core_attention_bias=bias, # None - ) + out = block( + q, + k, + v, + attention_mask=attention_mask, # attention_mask + qkv_format="bshd", + attn_mask_type=config.attn_mask_type, # 'arbitrary' + core_attention_bias_type=config.attn_bias_type, # 'no_bias' + core_attention_bias=bias, # None + ) out.backward(out_grad) if config.attn_bias_type == "post_scale_bias": - out = 
block(q, k, v, - attention_mask=attention_mask, # None - qkv_format='bshd', - attn_mask_type=config.attn_mask_type, # no_mask - core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias' - core_attention_bias=bias, # bias - ) + out = block( + q, + k, + v, + attention_mask=attention_mask, # None + qkv_format="bshd", + attn_mask_type=config.attn_mask_type, # no_mask + core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias' + core_attention_bias=bias, # bias + ) out.backward(out_grad) return out, (q.grad, k.grad, v.grad) @@ -107,19 +125,19 @@ def _run_dot_product_attention( dtype = torch.bfloat16 model_configs = { # test: b, h, hg, d, sq, skv, p, mask, bias - "test_mask": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "arbitrary", "no_bias"), - "test_bias": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "post_scale_bias"), + "test_mask": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "arbitrary", "no_bias"), + "test_bias": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "post_scale_bias"), } -print('Run with post_scale_bias:') +print("Run with post_scale_bias:") config = model_configs["test_bias"] -fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd') +fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd") -print('Run with arbitrary mask:') +print("Run with arbitrary mask:") config = model_configs["test_mask"] -unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd') +unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd") torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2) for i in range(3): torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2) -print('Test passed!') +print("Test passed!") diff --git a/docs/examples/attention/example_attention.py b/docs/examples/attention/example_attention.py index 5eac38f99bd1a92b4b1f82896d1689f8cb31a4a5..2ed73034179e3603a88100beaeb94cfbade42908 100644 --- a/docs/examples/attention/example_attention.py +++ b/docs/examples/attention/example_attention.py @@ -14,7 +14,7 @@ from tests.pytorch.fused_attn.test_fused_attn import ( _is_flash_attention_supported, _is_fused_attention_supported, _is_unfused_attention_supported, - _run_dot_product_attention + _run_dot_product_attention, ) # data type @@ -26,7 +26,7 @@ ckpt_attn = False # workspace optimization path for cuDNN attention workspace_opt = True # QKV memory layout -qkv_layout = 'bshd_bshd_bshd' +qkv_layout = "bshd_bshd_bshd" # sliding window attention swa = False # padding between sequences for qkv_format=thd @@ -36,12 +36,13 @@ is_training = True model_configs = { # test: b, h, hg, d, sq, skv, p, mask, bias - "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq - "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask - "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias - "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA + "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq + "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask + "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias + "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA } + def example_attention(model, fused_attn_supported, 
flash_attn_supported): config = model_configs[model] if dtype == torch.bfloat16: @@ -51,40 +52,58 @@ def example_attention(model, fused_attn_supported, flash_attn_supported): if fused_attn_supported: print() - print('Run cuDNN attention...') + print("Run cuDNN attention...") fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( - dtype, config, "FusedAttention", - ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training, + dtype, + config, + "FusedAttention", + ckpt_attn, + qkv_layout, + workspace_opt, + swa, + pad_between_seqs, + is_training, ) if flash_attn_supported: print() - print('Run flash-attention...') + print("Run flash-attention...") flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( - dtype, config, "FlashAttention", - ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training, + dtype, + config, + "FlashAttention", + ckpt_attn, + qkv_layout, + workspace_opt, + swa, + pad_between_seqs, + is_training, ) if fused_attn_supported and flash_attn_supported: torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) - for i,_ in enumerate(flash_attn_bwd): + for i, _ in enumerate(flash_attn_bwd): torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols) print() - print('Test passed.') + print("Test passed.") + def main(): - models = ['test_0'] + models = ["test_0"] for model in models: config = model_configs[model] fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( - config, dtype, qkv_layout=qkv_layout, + config, + dtype, + qkv_layout=qkv_layout, ) fused_attn_supported = fused_attn_supported and not swa flash_attn_supported = _is_flash_attention_supported(config) example_attention(model, fused_attn_supported, flash_attn_supported) + if __name__ == "__main__": main() diff --git a/docs/examples/quickstart_utils.py b/docs/examples/quickstart_utils.py index 9532dd27fabf16b4b0d85de5c4f8ca1454b3ffec..0582efd52e629613acdd89778ccbbb0a1ac6cdd4 100644 --- a/docs/examples/quickstart_utils.py +++ b/docs/examples/quickstart_utils.py @@ -8,14 +8,15 @@ import torch import transformer_engine.pytorch as te from transformer_engine.pytorch.fp8 import DelayedScaling, dist_group_type + def speedometer( - module: torch.nn.Module, - input: torch.Tensor, - output_grad: torch.Tensor, - forward_kwargs: dict = {}, - fp8_autocast_kwargs: Optional[dict] = None, - timing_iters: int = 50, - warmup_iters: int = 50, + module: torch.nn.Module, + input: torch.Tensor, + output_grad: torch.Tensor, + forward_kwargs: dict = {}, + fp8_autocast_kwargs: Optional[dict] = None, + timing_iters: int = 50, + warmup_iters: int = 50, ) -> None: """Measure average run time for a PyTorch module @@ -24,7 +25,7 @@ def speedometer( start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) if fp8_autocast_kwargs is None: - fp8_autocast_kwargs = { "enabled": False } + fp8_autocast_kwargs = {"enabled": False} # Warmup runs torch.cuda.synchronize() @@ -51,11 +52,12 @@ class DotProductAttention(torch.nn.Module): Built with plain PyTorch modules. 
""" + def __init__( - self, - num_attention_heads: int, - kv_channels: int, - attention_dropout: float, + self, + num_attention_heads: int, + kv_channels: int, + attention_dropout: float, ) -> None: super().__init__() self.projection_size = kv_channels * num_attention_heads @@ -63,21 +65,17 @@ class DotProductAttention(torch.nn.Module): self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) self.dropout = torch.nn.Dropout(attention_dropout) - def masked_softmax( - self, - inp: torch.Tensor, - mask: Optional[torch.Tensor] - ) -> torch.Tensor: + def masked_softmax(self, inp: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor: if mask is not None: inp.masked_fill_(mask, -10000.0) return torch.nn.Softmax(dim=-1)(inp) def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: b = query.size(1) np = query.size(2) @@ -90,7 +88,9 @@ class DotProductAttention(torch.nn.Module): # [sk, b, np, hn] -> [sk, b * np, hn] key = key.view(sk, b * np, -1) - bmm1 = torch.bmm(query.transpose(0, 1), key.transpose(0, 1).transpose(1, 2)) / self.norm_factor + bmm1 = ( + torch.bmm(query.transpose(0, 1), key.transpose(0, 1).transpose(1, 2)) / self.norm_factor + ) # change view to [b, np, sq, sk] attention_scores = bmm1.view(b, np, sq, sk) @@ -126,10 +126,11 @@ class BasicMLP(torch.nn.Module): Built with plain PyTorch modules. """ + def __init__( - self, - hidden_size: int, - ffn_hidden_size: int, + self, + hidden_size: int, + ffn_hidden_size: int, ) -> None: super().__init__() self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True) @@ -137,7 +138,7 @@ class BasicMLP(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.linear1(x) - x = torch.nn.functional.gelu(x, approximate='tanh') + x = torch.nn.functional.gelu(x, approximate="tanh") x = self.linear2(x) return x @@ -148,7 +149,7 @@ def share_parameters_with_basic_te_model(te_model, basic_model): Parameter values are copied from pure PyTorch implementation. 
""" - te_model.ln1.weight= basic_model.ln1.weight + te_model.ln1.weight = basic_model.ln1.weight te_model.ln1.bias = basic_model.ln1.bias te_model.qkv_projection.weight = basic_model.qkv_projection.weight te_model.qkv_projection.bias = basic_model.qkv_projection.bias @@ -202,14 +203,15 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model): te_model.layernorm_mlp.fc2_bias = basic_model.mlp.linear2.bias -def cast_to_representable(inp, scale = 1., fp8_format='e4m3'): +def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"): import transformer_engine.pytorch.cpp_extensions as texcpp import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType - fp8_type = tex.DType.kFloat8E4M3 if fp8_format == 'e4m3' else tex.DType.kFloat8E5M2 + + fp8_type = tex.DType.kFloat8E4M3 if fp8_format == "e4m3" else tex.DType.kFloat8E5M2 input_type = TE_DType[inp.dtype] meta = tex.FP8TensorMeta() - meta.scale = torch.ones(1,dtype=torch.float32, device="cuda") * scale + meta.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale meta.scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") / scale meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda") ret = texcpp.cast_to_fp8(inp, meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_type) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index 307507ad1d341b46bcec2a078d882d338e748a5c..cb384aa10c41cd747e298e5b16f1d51b788516bb 100644 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -15,11 +15,17 @@ from transformer_engine.pytorch.attention import RotaryPositionEmbedding from transformer_engine.pytorch.fp8 import fp8_model_init import transformers -from transformers.models.llama.modeling_llama import LlamaModel, LlamaForCausalLM, LlamaRMSNorm, LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaModel, + LlamaForCausalLM, + LlamaRMSNorm, + LlamaConfig, +) from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model from transformers.utils import WEIGHTS_INDEX_NAME from transformers.utils.hub import get_checkpoint_shard_files + @contextmanager def replace_decoder(te_decoder_cls): """ @@ -43,6 +49,7 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer): args: positional args (for compatibility with `LlamaDecoderLayer`) kwargs: keyword args (for compatibility with `LlamaDecoderLayer`) """ + def __init__(self, config, *args, **kwargs): super().__init__( hidden_size=config.hidden_size, @@ -56,22 +63,22 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer): normalization="RMSNorm", activation="swiglu", attn_input_format="bshd", - num_gqa_groups=config.num_key_value_heads + num_gqa_groups=config.num_key_value_heads, ) - te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads) + te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() - def forward(self, - hidden_states, - *args, - attention_mask, - **kwargs): + def forward(self, hidden_states, *args, attention_mask, **kwargs): """ Custom forward to make sure we only pass relevant arguments to the forward pass of the `TransformerLayer`. Also, make sure the output format matches the output of the HF's `LlamaDecoderLayer`. 
""" - return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),) + return ( + super().forward( + hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb + ), + ) class TELlamaForCausalLM: @@ -95,21 +102,29 @@ class TELlamaForCausalLM: Custom method adapted from `from_pretrained` method in HuggingFace Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 """ - vanilla_model = cls(config).to(kwargs['torch_dtype']) + vanilla_model = cls(config).to(kwargs["torch_dtype"]) is_local = os.path.isdir(pretrained_model_name_or_path) subfolder = "" variant = None if os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant)) - ): + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant("model.safetensors.index.json", variant), + ) + ): # Load from a sharded PyTorch checkpoint archive_file = os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant) + pretrained_model_name_or_path, + subfolder, + _add_variant("model.safetensors.index.json", variant), ) is_sharded = True elif os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) - ): + os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) + ) + ): # Load from a sharded PyTorch checkpoint archive_file = os.path.join( pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) @@ -118,10 +133,9 @@ class TELlamaForCausalLM: else: raise AssertionError("Only sharded PyTorch ckpt format supported at the moment") - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( - pretrained_model_name_or_path, - archive_file, + pretrained_model_name_or_path, + archive_file, ) # If the checkpoint is not sharded, it's a trivial sharding case @@ -142,48 +156,63 @@ class TELlamaForCausalLM: return vanilla_model + def replace_params(hf_state_dict, te_state_dict, config): # collect all layer prefixes to update all_layer_prefixes = set() for param_key in hf_state_dict.keys(): - layer_prefix_pat = 'model.layers.\d+.' + layer_prefix_pat = "model.layers.\d+." 
m = re.match(layer_prefix_pat, param_key) if m is not None: all_layer_prefixes.add(m.group()) - - for layer_prefix in all_layer_prefixes: # When loading weights into models with less number of layers, skip the # copy if the corresponding layer doesn't exist in HF model - if layer_prefix + 'input_layernorm.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:] + if layer_prefix + "input_layernorm.weight" in hf_state_dict: + te_state_dict[layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"].data[ + : + ] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:] + + if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "self_attention.layernorm_qkv.query_weight"].data[:] = ( + hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:] + ) - if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:] + if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "self_attention.layernorm_qkv.key_weight"].data[:] = ( + hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:] + ) - if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:] + if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "self_attention.layernorm_qkv.value_weight"].data[:] = ( + hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:] + ) - if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:] + if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = hf_state_dict[ + layer_prefix + "self_attn.o_proj.weight" + ].data[:] - if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:] + if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict: + te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = hf_state_dict[ + layer_prefix + "post_attention_layernorm.weight" + ].data[:] - if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:] - # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to # load them separately. 
- if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.intermediate_size] = \ - hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data - - if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.intermediate_size:] = \ - hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data - - if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:] - return all_layer_prefixes \ No newline at end of file + if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[ + : config.intermediate_size + ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data + + if layer_prefix + "mlp.up_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[ + config.intermediate_size : + ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data + + if layer_prefix + "mlp.down_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = hf_state_dict[ + layer_prefix + "mlp.down_proj.weight" + ].data[:] + return all_layer_prefixes diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index 71d2aa2e2ef4deedfcb7d2bcc797283ffd80887b..b6b3683d4c397008ad7273818fe62af7c4df153a 100644 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -10,29 +10,36 @@ import torch from torch.optim import AdamW from torch.utils.data import DataLoader -from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup, AutoConfig +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + get_linear_schedule_with_warmup, + AutoConfig, +) from transformers import DataCollatorForLanguageModeling from datasets import load_dataset from accelerate import Accelerator from accelerate.utils.dataclasses import FP8RecipeKwargs + class HyperParameters: def __init__(self): self.mixed_precision = "bf16" - #self.model_name = "" # <== Add model weight location here + # self.model_name = "" # <== Add model weight location here self.dataset_name = "timdettmers/openassistant-guanaco" self.dataset_text_field = "text" self.learning_rate = 1.41e-5 self.batch_size = 8 self.max_seq_length = 256 self.gradient_accumulation_steps = 1 - self.num_warmup_steps=5 - self.num_training_steps=10 - + self.num_warmup_steps = 5 + self.num_training_steps = 10 + hyperparams = HyperParameters() -def get_dataloaders(accelerator:Accelerator, hyperparams): + +def get_dataloaders(accelerator: Accelerator, hyperparams): dataset = load_dataset(hyperparams.dataset_name, split="train") tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) if getattr(tokenizer, "pad_token", None) is None: @@ -45,16 +52,12 @@ def get_dataloaders(accelerator:Accelerator, hyperparams): padding=False, max_length=hyperparams.max_seq_length, return_overflowing_tokens=False, - return_length=False + return_length=False, ) return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} with accelerator.main_process_first(): - dataset = dataset.map( - tokenize, - batched=True, - remove_columns=dataset.column_names - ) + dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names) # Simply pad to the multiple of 16 for both FP8 and BF16 precision 
pad_to_multiple_of = 16 @@ -72,6 +75,7 @@ def get_dataloaders(accelerator:Accelerator, hyperparams): train_dataloader = DataLoader(dataset, **dataloader_params) return train_dataloader + def init_baseline_model(hyperparams): # Init the model config = AutoConfig.from_pretrained(hyperparams.model_name) @@ -84,42 +88,47 @@ def init_baseline_model(hyperparams): ) model = model.cuda() # Needed for the cases when using TELlamaForCausalLM. So adding here for 1:1 comparison - model.config.use_cache=False + model.config.use_cache = False return model + def init_te_llama_model(hyperparams): # Init the model from te_llama import TELlamaForCausalLM + config = AutoConfig.from_pretrained(hyperparams.model_name) config._attn_implementation = "flash_attention_2" model = TELlamaForCausalLM.from_pretrained_local( - hyperparams.model_name, - config=config, - torch_dtype=torch.bfloat16, + hyperparams.model_name, + config=config, + torch_dtype=torch.bfloat16, ) model = model.cuda() # Needed for the cases when using TELlamaForCausalLM - model.config.use_cache=False + model.config.use_cache = False return model + def wrap_with_accelerator(model, hyperparams): # Create FP8 kwarg handler if required - fp8_kwarg_handler = [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None + fp8_kwarg_handler = ( + [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None + ) # Init HF accelerator that's used for training accelerator = Accelerator( log_with="wandb", gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, mixed_precision=hyperparams.mixed_precision, - kwargs_handlers=fp8_kwarg_handler + kwargs_handlers=fp8_kwarg_handler, ) - #accelerator.print(f'State: {accelerator.state}') + # accelerator.print(f'State: {accelerator.state}') train_dataloader = get_dataloaders(accelerator, hyperparams) # Wrap model, optimizer/scheduler, dataloaders in accelerate - optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=True) + optimizer = AdamW(params=model.parameters(), lr=hyperparams.learning_rate, fused=True) lr_scheduler = get_linear_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=100, @@ -131,6 +140,7 @@ def wrap_with_accelerator(model, hyperparams): return accelerator, model, optimizer, train_dataloader, lr_scheduler + def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler): model.train() total_loss = 0 @@ -170,7 +180,11 @@ def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, end.record() accelerator.end_training() - print(f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step: {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds") + print( + f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step:" + f" {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds" + ) + def restart_jupyter_notebook(): # Try restarting the Jupyter kernel @@ -179,18 +193,23 @@ def restart_jupyter_notebook(): # Check whether the device memory has been flushed if torch.cuda.memory_allocated() != 0: import warnings + warnings.warn("The device memory hasn't been flushed, trying with a second method!") # Try restarting the Jupyter kernel another way # Restart the kernel from IPython.core.display import HTML + HTML("") if torch.cuda.memory_allocated() != 0: - print("The device memory hasn't been flushed, try manually restarting the Jupyter kernel!") + print( + "The device 
memory hasn't been flushed, try manually restarting the Jupyter kernel!" + ) # Suppress the warnings if not sys.warnoptions: import warnings + warnings.simplefilter("ignore") torch.set_warn_always(False) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index b02b3eddd90b2eb2de4ea40d8d46e7ec68a4172a..716d543d5b9e7f68d3cc3f0745376d24545d05b1 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -22,18 +22,19 @@ from jax.experimental.pjit import pjit import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -DEVICE_DP_AXIS = 'data' -DEVICE_TP_AXIS = 'model' -NAMED_BROADCAST_AXIS = 'my_broadcast_axis' -NAMED_TP_AXIS = 'my_tp_axis' -PARAMS_KEY = 'params' -PARAMS_AXES_KEY = PARAMS_KEY + '_axes' -DROPOUT_KEY = 'dropout' -INPUT_KEY = 'input_rng' +DEVICE_DP_AXIS = "data" +DEVICE_TP_AXIS = "model" +NAMED_BROADCAST_AXIS = "my_broadcast_axis" +NAMED_TP_AXIS = "my_tp_axis" +PARAMS_KEY = "params" +PARAMS_AXES_KEY = PARAMS_KEY + "_axes" +DROPOUT_KEY = "dropout" +INPUT_KEY = "input_rng" class Net(nn.Module): """NLP Encoder""" + num_embed: int enable_seq_paral: bool @@ -41,36 +42,43 @@ class Net(nn.Module): def __call__(self, x, mask, disable_dropout=False): x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) - te_Encoder = partial(te_flax.TransformerLayer, - hidden_size=256, - mlp_hidden_size=1024, - num_attention_heads=8, - hidden_dropout=0.1, - attention_dropout=0.1, - dropout_rng_name=DROPOUT_KEY, - layer_type=te_flax.TransformerLayerType.ENCODER, - self_attn_mask_type='padding', - enable_relative_embedding=False, - enable_sequence_parallel=self.enable_seq_paral, - dtype=jnp.bfloat16) + te_Encoder = partial( + te_flax.TransformerLayer, + hidden_size=256, + mlp_hidden_size=1024, + num_attention_heads=8, + hidden_dropout=0.1, + attention_dropout=0.1, + dropout_rng_name=DROPOUT_KEY, + layer_type=te_flax.TransformerLayerType.ENCODER, + self_attn_mask_type="padding", + enable_relative_embedding=False, + enable_sequence_parallel=self.enable_seq_paral, + dtype=jnp.bfloat16, + ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) if self.enable_seq_paral: # Trigger all-gather to collect a complete tensor alone seqence on each device. 
- x = jax.lax.with_sharding_constraint(x, - jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)) - - x = te_flax.DenseGeneral(features=256, - kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), - bias_axes=(NAMED_TP_AXIS,), - dtype=jnp.bfloat16)(x) - - x = te_flax.DenseGeneral(features=256, - kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), - bias_axes=(NAMED_BROADCAST_AXIS,), - dtype=jnp.bfloat16)(x) + x = jax.lax.with_sharding_constraint( + x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) + ) + + x = te_flax.DenseGeneral( + features=256, + kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), + bias_axes=(NAMED_TP_AXIS,), + dtype=jnp.bfloat16, + )(x) + + x = te_flax.DenseGeneral( + features=256, + kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), + bias_axes=(NAMED_BROADCAST_AXIS,), + dtype=jnp.bfloat16, + )(x) x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) return x @@ -98,20 +106,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): """Train for a single epoch.""" - train_ds_size = len(train_ds['sentence']) + train_ds_size = len(train_ds["sentence"]) steps_per_epoch = train_ds_size // batch_size perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size) - perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch + perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch perms = perms.reshape((steps_per_epoch, batch_size)) epoch_loss = [] epoch_accuracy = [] for perm in perms: - batch_inputs = train_ds['sentence'][perm, ...] - batch_masks = train_ds['mask'][perm, ...] - batch_labels = train_ds['label'][perm, ...] - state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks, - batch_labels, var_collect, rngs) + batch_inputs = train_ds["sentence"][perm, ...] + batch_masks = train_ds["mask"][perm, ...] + batch_labels = train_ds["label"][perm, ...] 
+ state, loss, accuracy, var_collect = train_fn( + state, batch_inputs, batch_masks, batch_labels, var_collect, rngs + ) epoch_loss.append(loss) epoch_accuracy.append(accuracy) @@ -137,7 +146,7 @@ def eval_step(state, inputs, masks, labels, var_collect): def eval_model(state, test_ds, batch_size, var_collect, eval_fn): """Evaluation loop.""" - test_ds_size = len(test_ds['sentence']) + test_ds_size = len(test_ds["sentence"]) num_steps = test_ds_size // batch_size valid_size = num_steps * batch_size all_loss = [] @@ -145,9 +154,9 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): for batch_start in range(0, valid_size, batch_size): batch_end = batch_start + batch_size - batch_inputs = test_ds['sentence'][batch_start:batch_end] - batch_masks = test_ds['mask'][batch_start:batch_end] - batch_labels = test_ds['label'][batch_start:batch_end] + batch_inputs = test_ds["sentence"][batch_start:batch_end] + batch_masks = test_ds["mask"][batch_start:batch_end] + batch_labels = test_ds["label"][batch_start:batch_end] loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect) all_loss.append(loss) all_accuracy.append(accuracy) @@ -159,12 +168,12 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download('punkt') - dataset_size = len(dataset['sentence']) + nltk.download("punkt") + dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) - for j, sentence in enumerate(dataset['sentence']): + for j, sentence in enumerate(dataset["sentence"]): tokens = nltk.word_tokenize(sentence) tensor = output[j] @@ -184,9 +193,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): mask_2d[:seq_len, :seq_len] = 0 new_dataset = { - 'sentence': output, - 'label': dataset['label'].astype(np.float32), - 'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) + "sentence": output, + "label": dataset["label"].astype(np.float32), + "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)), } return new_dataset, vocab, word_id @@ -196,12 +205,12 @@ def get_datasets(max_seq_len): vocab = {} word_id = 0 - train_ds = load_dataset('glue', 'cola', split='train') - train_ds.set_format(type='np') + train_ds = load_dataset("glue", "cola", split="train") + train_ds.set_format(type="np") train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) - test_ds = load_dataset('glue', 'cola', split='validation') - test_ds.set_format(type='np') + test_ds = load_dataset("glue", "cola", split="validation") + test_ds.set_format(type="np") test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) return train_ds, test_ds, word_id @@ -210,7 +219,8 @@ def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." 
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} assert "fp8_" in str( - jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) + jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs) + ) def get_params_pspec(sharding_rules, abs_var_collect): @@ -255,8 +265,9 @@ def train_and_evaluate(args): num_gpu_tp = 1 assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}" - assert args.test_batch_size % num_gpu_dp == 0, \ - f"Test batch size needs to be multiple of {num_gpu_dp}" + assert ( + args.test_batch_size % num_gpu_dp == 0 + ), f"Test batch size needs to be multiple of {num_gpu_dp}" device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)): @@ -270,9 +281,9 @@ def train_and_evaluate(args): mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - with te.fp8_autocast(args.use_fp8, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, - None)): + with te.fp8_autocast( + args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None) + ): encoder = Net(num_embed, args.enable_sp) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) @@ -285,18 +296,21 @@ def train_and_evaluate(args): masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) in_shardings = (None, inputs_pspec, masks_pspec) - out_shardings = {key: params_pspec if key is PARAMS_KEY else None \ - for key in abs_var_collect} + out_shardings = { + key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect + } pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) var_collect = pjit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) - state = train_state.TrainState.create(apply_fn=encoder.apply, - params=params, - tx=optimizer) + state = train_state.TrainState.create( + apply_fn=encoder.apply, params=params, tx=optimizer + ) state_pspec = get_state_pspec(state, params_pspec) - labels_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS,) + labels_pspec = jax.sharding.PartitionSpec( + DEVICE_DP_AXIS, + ) in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) out_shardings = (state_pspec, None, None, None) @@ -323,16 +337,20 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step) + state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step + ) - test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, - var_collect, pjit_eval_step) + test_loss, test_accuracy = eval_model( + state, test_ds, args.test_batch_size, var_collect, pjit_eval_step + ) - print(f"Epoch: {epoch:>2} " - f"Train Loss: {train_loss:.6f} " - f"Train Accuracy: {train_accuracy:.6f} " - f"Test Loss: {test_loss:.6f} " - f"Test Accuracy: {test_accuracy:.6f} ") + print( + f"Epoch: {epoch:>2} " + f"Train Loss: {train_loss:.6f} " + f"Train Accuracy: {train_accuracy:.6f} " + f"Test Loss: {test_loss:.6f} " + f"Test Accuracy: {test_accuracy:.6f} " + ) return [train_loss, train_accuracy, test_loss, test_accuracy] @@ -382,14 +400,15 @@ def encoder_parser(args): help="quickly check a single pass", ) parser.add_argument("--seed", type=int, 
default=0, metavar="S", help="random seed (default: 0)") - parser.add_argument("--use-fp8", - action="store_true", - default=False, - help="Use FP8 for inference and training without recalibration") - parser.add_argument("--enable-sp", - action="store_true", - default=False, - help="Enable sequence parallelism.") + parser.add_argument( + "--use-fp8", + action="store_true", + default=False, + help="Use FP8 for inference and training without recalibration", + ) + parser.add_argument( + "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism." + ) return parser.parse_args(args) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index f7c1779ca15f85885f12e8f201f0c9a76daa3490..c6223ed5bb782d4c7327d840b8b30d290d13ed0d 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -22,32 +22,35 @@ from jax.experimental.pjit import pjit import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -DEVICE_DP_AXIS = 'data' -PARAMS_KEY = 'params' -PARAMS_AXES_KEY = PARAMS_KEY + '_axes' -DROPOUT_KEY = 'dropout' -INPUT_KEY = 'input_rng' +DEVICE_DP_AXIS = "data" +PARAMS_KEY = "params" +PARAMS_AXES_KEY = PARAMS_KEY + "_axes" +DROPOUT_KEY = "dropout" +INPUT_KEY = "input_rng" class Net(nn.Module): """NLP Encoder""" + num_embed: int @nn.compact def __call__(self, x, mask, disable_dropout=False): x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) - te_Encoder = partial(te_flax.TransformerLayer, - hidden_size=256, - mlp_hidden_size=1024, - num_attention_heads=8, - hidden_dropout=0.1, - attention_dropout=0.1, - dropout_rng_name=DROPOUT_KEY, - layer_type=te_flax.TransformerLayerType.ENCODER, - self_attn_mask_type='padding', - enable_relative_embedding=False, - dtype=jnp.bfloat16) + te_Encoder = partial( + te_flax.TransformerLayer, + hidden_size=256, + mlp_hidden_size=1024, + num_attention_heads=8, + hidden_dropout=0.1, + attention_dropout=0.1, + dropout_rng_name=DROPOUT_KEY, + layer_type=te_flax.TransformerLayerType.ENCODER, + self_attn_mask_type="padding", + enable_relative_embedding=False, + dtype=jnp.bfloat16, + ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) @@ -82,20 +85,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): """Train for a single epoch.""" - train_ds_size = len(train_ds['sentence']) + train_ds_size = len(train_ds["sentence"]) steps_per_epoch = train_ds_size // batch_size perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size) - perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch + perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch perms = perms.reshape((steps_per_epoch, batch_size)) epoch_loss = [] epoch_accuracy = [] for perm in perms: - batch_inputs = train_ds['sentence'][perm, ...] - batch_masks = train_ds['mask'][perm, ...] - batch_labels = train_ds['label'][perm, ...] - state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks, - batch_labels, var_collect, rngs) + batch_inputs = train_ds["sentence"][perm, ...] + batch_masks = train_ds["mask"][perm, ...] + batch_labels = train_ds["label"][perm, ...] 
+ state, loss, accuracy, var_collect = train_fn( + state, batch_inputs, batch_masks, batch_labels, var_collect, rngs + ) epoch_loss.append(loss) epoch_accuracy.append(accuracy) @@ -121,7 +125,7 @@ def eval_step(state, inputs, masks, labels, var_collect): def eval_model(state, test_ds, batch_size, var_collect, eval_fn): """Evaluation loop.""" - test_ds_size = len(test_ds['sentence']) + test_ds_size = len(test_ds["sentence"]) num_steps = test_ds_size // batch_size valid_size = num_steps * batch_size all_loss = [] @@ -129,9 +133,9 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): for batch_start in range(0, valid_size, batch_size): batch_end = batch_start + batch_size - batch_inputs = test_ds['sentence'][batch_start:batch_end] - batch_masks = test_ds['mask'][batch_start:batch_end] - batch_labels = test_ds['label'][batch_start:batch_end] + batch_inputs = test_ds["sentence"][batch_start:batch_end] + batch_masks = test_ds["mask"][batch_start:batch_end] + batch_labels = test_ds["label"][batch_start:batch_end] loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect) all_loss.append(loss) all_accuracy.append(accuracy) @@ -143,12 +147,12 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download('punkt') - dataset_size = len(dataset['sentence']) + nltk.download("punkt") + dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) - for j, sentence in enumerate(dataset['sentence']): + for j, sentence in enumerate(dataset["sentence"]): tokens = nltk.word_tokenize(sentence) tensor = output[j] @@ -168,9 +172,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): mask_2d[:seq_len, :seq_len] = 0 new_dataset = { - 'sentence': output, - 'label': dataset['label'].astype(np.float32), - 'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) + "sentence": output, + "label": dataset["label"].astype(np.float32), + "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)), } return new_dataset, vocab, word_id @@ -180,12 +184,12 @@ def get_datasets(max_seq_len): vocab = {} word_id = 0 - train_ds = load_dataset('glue', 'cola', split='train') - train_ds.set_format(type='np') + train_ds = load_dataset("glue", "cola", split="train") + train_ds.set_format(type="np") train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) - test_ds = load_dataset('glue', 'cola', split='validation') - test_ds.set_format(type='np') + test_ds = load_dataset("glue", "cola", split="validation") + test_ds.set_format(type="np") test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) return train_ds, test_ds, word_id @@ -194,7 +198,8 @@ def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." 
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} assert "fp8_" in str( - jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) + jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs) + ) def get_params_pspec(sharding_rules, abs_var_collect): @@ -232,8 +237,7 @@ def train_and_evaluate(args): num_gpu = jax.local_device_count() assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}" - assert args.test_batch_size % num_gpu == 0, \ - f"Test batch size needs to be multiple of {num_gpu}" + assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}" device_mesh = mesh_utils.create_device_mesh((num_gpu,)) with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)): @@ -247,8 +251,9 @@ def train_and_evaluate(args): mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - with te.fp8_autocast(args.use_fp8, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)): + with te.fp8_autocast( + args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None) + ): encoder = Net(num_embed) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) @@ -260,18 +265,21 @@ def train_and_evaluate(args): masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) in_shardings = (None, inputs_pspec, masks_pspec) - out_shardings = {key: params_pspec if key is PARAMS_KEY else None \ - for key in abs_var_collect} + out_shardings = { + key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect + } pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) var_collect = pjit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) - state = train_state.TrainState.create(apply_fn=encoder.apply, - params=params, - tx=optimizer) + state = train_state.TrainState.create( + apply_fn=encoder.apply, params=params, tx=optimizer + ) state_pspec = get_state_pspec(state, params_pspec) - labels_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS,) + labels_pspec = jax.sharding.PartitionSpec( + DEVICE_DP_AXIS, + ) in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) out_shardings = (state_pspec, None, None, None) @@ -298,16 +306,20 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step) + state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step + ) - test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, - var_collect, pjit_eval_step) + test_loss, test_accuracy = eval_model( + state, test_ds, args.test_batch_size, var_collect, pjit_eval_step + ) - print(f"Epoch: {epoch:>2} " - f"Train Loss: {train_loss:.6f} " - f"Train Accuracy: {train_accuracy:.6f} " - f"Test Loss: {test_loss:.6f} " - f"Test Accuracy: {test_accuracy:.6f} ") + print( + f"Epoch: {epoch:>2} " + f"Train Loss: {train_loss:.6f} " + f"Train Accuracy: {train_accuracy:.6f} " + f"Test Loss: {test_loss:.6f} " + f"Test Accuracy: {test_accuracy:.6f} " + ) return [train_loss, train_accuracy, test_loss, test_accuracy] @@ -357,10 +369,12 @@ def encoder_parser(args): help="quickly check a single pass", ) parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)") - 
parser.add_argument("--use-fp8", - action="store_true", - default=False, - help="Use FP8 for inference and training without recalibration") + parser.add_argument( + "--use-fp8", + action="store_true", + default=False, + help="Use FP8 for inference and training without recalibration", + ) return parser.parse_args(args) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 1330d5c84500ad903802fe7fb3d99f0d69edbfe0..c9620aa2be65ccf77b697164e70831d1d5bf3b9e 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -24,49 +24,56 @@ from jax.experimental.pjit import pjit import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' -DEVICE_DP_AXIS = 'data' -DEVICE_TP_AXIS = 'model' -NAMED_BROADCAST_AXIS = 'my_broadcast_axis' -NAMED_TP_AXIS = 'my_tp_axis' -PARAMS_KEY = 'params' -PARAMS_AXES_KEY = PARAMS_KEY + '_axes' -DROPOUT_KEY = 'dropout' -INPUT_KEY = 'input_rng' +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +DEVICE_DP_AXIS = "data" +DEVICE_TP_AXIS = "model" +NAMED_BROADCAST_AXIS = "my_broadcast_axis" +NAMED_TP_AXIS = "my_tp_axis" +PARAMS_KEY = "params" +PARAMS_AXES_KEY = PARAMS_KEY + "_axes" +DROPOUT_KEY = "dropout" +INPUT_KEY = "input_rng" class Net(nn.Module): """NLP Encoder""" + num_embed: int @nn.compact def __call__(self, x, mask, disable_dropout=False): x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) - te_Encoder = partial(te_flax.TransformerLayer, - hidden_size=256, - mlp_hidden_size=1024, - num_attention_heads=8, - hidden_dropout=0.1, - attention_dropout=0.1, - dropout_rng_name=DROPOUT_KEY, - layer_type=te_flax.TransformerLayerType.ENCODER, - self_attn_mask_type='padding', - enable_relative_embedding=False, - dtype=jnp.bfloat16) + te_Encoder = partial( + te_flax.TransformerLayer, + hidden_size=256, + mlp_hidden_size=1024, + num_attention_heads=8, + hidden_dropout=0.1, + attention_dropout=0.1, + dropout_rng_name=DROPOUT_KEY, + layer_type=te_flax.TransformerLayerType.ENCODER, + self_attn_mask_type="padding", + enable_relative_embedding=False, + dtype=jnp.bfloat16, + ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) - x = te_flax.DenseGeneral(features=256, - kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), - bias_axes=(NAMED_TP_AXIS,), - dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral( + features=256, + kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), + bias_axes=(NAMED_TP_AXIS,), + dtype=jnp.bfloat16, + )(x) - x = te_flax.DenseGeneral(features=256, - kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), - bias_axes=(NAMED_BROADCAST_AXIS,), - dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral( + features=256, + kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), + bias_axes=(NAMED_BROADCAST_AXIS,), + dtype=jnp.bfloat16, + )(x) x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) return x @@ -90,8 +97,9 @@ def shard_array_wrapper(dataset, batch_size, mesh, pspec, enable_partition=False (dp_size, tp_size) = mesh.device_ids.shape valid_input_size, global_batch_size, num_steps, tp_group_id = valid_shard_size( - total_input_size, batch_size, dp_size, tp_size) - inputs = inputs[:valid_input_size] # skip incomplete batch + total_input_size, batch_size, dp_size, tp_size + ) + inputs = inputs[:valid_input_size] # skip incomplete batch single_input_shape = inputs.shape[1:] global_input_shape = 
(global_batch_size, *single_input_shape) @@ -124,25 +132,39 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): return state, loss, accuracy, var_collect -def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn, mesh, inputs_pspec, - masks_pspec, labels_pspec): +def train_epoch( + state, + train_ds, + batch_size, + rngs, + var_collect, + train_fn, + mesh, + inputs_pspec, + masks_pspec, + labels_pspec, +): """Train for a single epoch.""" - total_batch_size = len(train_ds['sentence']) + total_batch_size = len(train_ds["sentence"]) (dp_size, tp_size) = mesh.device_ids.shape - valid_size, _, num_steps, tp_group_id = valid_shard_size(total_batch_size, batch_size, dp_size, - tp_size) + valid_size, _, num_steps, tp_group_id = valid_shard_size( + total_batch_size, batch_size, dp_size, tp_size + ) perms = jax.random.permutation(rngs[INPUT_KEY], valid_size) perms = perms.reshape(dp_size, num_steps, batch_size) perms = perms[tp_group_id] global_input_shape, input_named_sharding, sentence = shard_array_wrapper( - train_ds['sentence'], batch_size, mesh, inputs_pspec) + train_ds["sentence"], batch_size, mesh, inputs_pspec + ) global_mask_shape, mask_named_sharding, mask = shard_array_wrapper( - train_ds['mask'], batch_size, mesh, masks_pspec) + train_ds["mask"], batch_size, mesh, masks_pspec + ) global_label_shape, label_named_sharding, label = shard_array_wrapper( - train_ds['label'], batch_size, mesh, labels_pspec) + train_ds["label"], batch_size, mesh, labels_pspec + ) epoch_loss = [] epoch_accuracy = [] @@ -152,15 +174,19 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn, mesh, batch_mask = mask[perm, ...] batch_label = label[perm, ...] - shard_input = jax.make_array_from_single_device_arrays(global_input_shape, - input_named_sharding, [batch_input]) - shard_mask = jax.make_array_from_single_device_arrays(global_mask_shape, - mask_named_sharding, [batch_mask]) - shard_label = jax.make_array_from_single_device_arrays(global_label_shape, - label_named_sharding, [batch_label]) - - state, loss, accuracy, var_collect = train_fn(state, shard_input, shard_mask, shard_label, - var_collect, rngs) + shard_input = jax.make_array_from_single_device_arrays( + global_input_shape, input_named_sharding, [batch_input] + ) + shard_mask = jax.make_array_from_single_device_arrays( + global_mask_shape, mask_named_sharding, [batch_mask] + ) + shard_label = jax.make_array_from_single_device_arrays( + global_label_shape, label_named_sharding, [batch_label] + ) + + state, loss, accuracy, var_collect = train_fn( + state, shard_input, shard_mask, shard_label, var_collect, rngs + ) epoch_loss.append(loss) epoch_accuracy.append(accuracy) @@ -184,36 +210,34 @@ def eval_step(state, inputs, masks, labels, var_collect): return loss, accuracy -def eval_model(state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_pspec, masks_pspec, - labels_pspec): +def eval_model( + state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_pspec, masks_pspec, labels_pspec +): """Evaluation loop.""" - global_input_shape, input_named_sharding, sentence = shard_array_wrapper(test_ds['sentence'], - batch_size, - mesh, - inputs_pspec, - enable_partition=True) - global_mask_shape, mask_named_sharding, mask = shard_array_wrapper(test_ds['mask'], - batch_size, - mesh, - masks_pspec, - enable_partition=True) - global_label_shape, label_named_sharding, label = shard_array_wrapper(test_ds['label'], - batch_size, - mesh, - labels_pspec, - enable_partition=True) + global_input_shape, 
input_named_sharding, sentence = shard_array_wrapper( + test_ds["sentence"], batch_size, mesh, inputs_pspec, enable_partition=True + ) + global_mask_shape, mask_named_sharding, mask = shard_array_wrapper( + test_ds["mask"], batch_size, mesh, masks_pspec, enable_partition=True + ) + global_label_shape, label_named_sharding, label = shard_array_wrapper( + test_ds["label"], batch_size, mesh, labels_pspec, enable_partition=True + ) all_loss = [] all_accuracy = [] for batch_input, batch_mask, batch_label in zip(sentence, mask, label): - shard_input = jax.make_array_from_single_device_arrays(global_input_shape, - input_named_sharding, [batch_input]) - shard_mask = jax.make_array_from_single_device_arrays(global_mask_shape, - mask_named_sharding, [batch_mask]) - shard_label = jax.make_array_from_single_device_arrays(global_label_shape, - label_named_sharding, [batch_label]) + shard_input = jax.make_array_from_single_device_arrays( + global_input_shape, input_named_sharding, [batch_input] + ) + shard_mask = jax.make_array_from_single_device_arrays( + global_mask_shape, mask_named_sharding, [batch_mask] + ) + shard_label = jax.make_array_from_single_device_arrays( + global_label_shape, label_named_sharding, [batch_label] + ) loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect) all_loss.append(loss) @@ -226,12 +250,12 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_ps def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download('punkt') - dataset_size = len(dataset['sentence']) + nltk.download("punkt") + dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) - for j, sentence in enumerate(dataset['sentence']): + for j, sentence in enumerate(dataset["sentence"]): tokens = nltk.word_tokenize(sentence) tensor = output[j] @@ -251,9 +275,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): mask_2d[:seq_len, :seq_len] = 0 new_dataset = { - 'sentence': output, - 'label': dataset['label'].astype(np.float32), - 'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) + "sentence": output, + "label": dataset["label"].astype(np.float32), + "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)), } return new_dataset, vocab, word_id @@ -263,12 +287,12 @@ def get_datasets(max_seq_len): vocab = {} word_id = 0 - train_ds = load_dataset('glue', 'cola', split='train') - train_ds.set_format(type='np') + train_ds = load_dataset("glue", "cola", split="train") + train_ds.set_format(type="np") train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) - test_ds = load_dataset('glue', 'cola', split='validation') - test_ds.set_format(type='np') + test_ds = load_dataset("glue", "cola", split="validation") + test_ds.set_format(type="np") test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) return train_ds, test_ds, word_id @@ -277,7 +301,8 @@ def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." 
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} assert "fp8_" in str( - jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) + jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs) + ) def get_params_pspec(sharding_rules, abs_var_collect): @@ -313,10 +338,12 @@ def train_and_evaluate(args): print(args) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) - jax.distributed.initialize(coordinator_address=args.coordinator_address, - num_processes=args.num_process, - process_id=args.process_id, - local_device_ids=args.process_id) + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + num_processes=args.num_process, + process_id=args.process_id, + local_device_ids=args.process_id, + ) assert jax.local_device_count() == 1, "1 GPU per process" num_gpu_tp = 2 @@ -328,12 +355,14 @@ def train_and_evaluate(args): num_gpu_tp = 1 assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}" - assert args.test_batch_size % num_gpu_dp == 0, \ - f"Test batch size needs to be multiple of {num_gpu_dp}" + assert ( + args.test_batch_size % num_gpu_dp == 0 + ), f"Test batch size needs to be multiple of {num_gpu_dp}" device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) - with jax.sharding.Mesh(devices=device_mesh, - axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)) as shard_mesh: + with jax.sharding.Mesh( + devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) + ) as shard_mesh: rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) @@ -344,9 +373,9 @@ def train_and_evaluate(args): mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - with te.fp8_autocast(args.use_fp8, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, - None)): + with te.fp8_autocast( + args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None) + ): encoder = Net(num_embed) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) @@ -359,18 +388,21 @@ def train_and_evaluate(args): masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) in_shardings = (None, inputs_pspec, masks_pspec) - out_shardings = {key: params_pspec if key is PARAMS_KEY else None \ - for key in abs_var_collect} + out_shardings = { + key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect + } pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) var_collect = pjit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) - state = train_state.TrainState.create(apply_fn=encoder.apply, - params=params, - tx=optimizer) + state = train_state.TrainState.create( + apply_fn=encoder.apply, params=params, tx=optimizer + ) state_pspec = get_state_pspec(state, params_pspec) - labels_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS,) + labels_pspec = jax.sharding.PartitionSpec( + DEVICE_DP_AXIS, + ) in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) out_shardings = (state_pspec, None, None, None) @@ -396,18 +428,37 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step, - shard_mesh, inputs_pspec, masks_pspec, labels_pspec) - - test_loss, test_accuracy = eval_model(state, test_ds, 
args.test_batch_size, - var_collect, pjit_eval_step, shard_mesh, - inputs_pspec, masks_pspec, labels_pspec) + state, + train_ds, + args.batch_size, + rngs, + var_collect, + pjit_train_step, + shard_mesh, + inputs_pspec, + masks_pspec, + labels_pspec, + ) + + test_loss, test_accuracy = eval_model( + state, + test_ds, + args.test_batch_size, + var_collect, + pjit_eval_step, + shard_mesh, + inputs_pspec, + masks_pspec, + labels_pspec, + ) if args.process_id == 0: - print(f"Epoch: {epoch:>2} " - f"Train Loss: {train_loss:.6f} " - f"Train Accuracy: {train_accuracy:.6f} " - f"Test Loss: {test_loss:.6f} " - f"Test Accuracy: {test_accuracy:.6f} ") + print( + f"Epoch: {epoch:>2} " + f"Train Loss: {train_loss:.6f} " + f"Train Accuracy: {train_accuracy:.6f} " + f"Test Loss: {test_loss:.6f} " + f"Test Accuracy: {test_accuracy:.6f} " + ) jax.distributed.shutdown() return [train_loss, train_accuracy, test_loss, test_accuracy] @@ -458,24 +509,31 @@ def encoder_parser(args): help="quickly check a single pass", ) parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)") - parser.add_argument("--use-fp8", - action="store_true", - default=False, - help="Use FP8 for inference and training without recalibration") - parser.add_argument("--coordinator-address", - type=str, - default="127.0.0.1:1234", - help="the IP address of process 0 and a port on \ - which that process should launch a coordinator service \ - (default: 127.0.0.1:1234)") - parser.add_argument("--num-process", - type=int, - default=1, - help="number of processes (default: 1)") - parser.add_argument("--process-id", - type=int, - default=0, - help="the ID number of the current process (default: 0)") + parser.add_argument( + "--use-fp8", + action="store_true", + default=False, + help="Use FP8 for inference and training without recalibration", + ) + parser.add_argument( + "--coordinator-address", + type=str, + default="127.0.0.1:1234", + help=( + "the IP address of process 0 and a port on which that" + " process should launch a coordinator service (default:" + " 127.0.0.1:1234)" + ), + ) + parser.add_argument( + "--num-process", type=int, default=1, help="number of processes (default: 1)" + ) + parser.add_argument( + "--process-id", + type=int, + default=0, + help="the ID number of the current process (default: 0)", + ) return parser.parse_args(args) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index b89243792550f5cb6e0724a0dca2e0363321ca03..674f7de81526157d4e2fa228421bb0f39b2af105 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -19,30 +19,33 @@ from flax.training import train_state import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -PARAMS_KEY = 'params' -DROPOUT_KEY = 'dropout' -INPUT_KEY = 'input_rng' +PARAMS_KEY = "params" +DROPOUT_KEY = "dropout" +INPUT_KEY = "input_rng" class Net(nn.Module): """NLP Encoder""" + num_embed: int @nn.compact def __call__(self, x, mask, disable_dropout=False): x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) - te_Encoder = partial(te_flax.TransformerLayer, - hidden_size=256, - mlp_hidden_size=1024, - num_attention_heads=8, - hidden_dropout=0.1, - attention_dropout=0.1, - dropout_rng_name=DROPOUT_KEY, - layer_type=te_flax.TransformerLayerType.ENCODER, - self_attn_mask_type='padding', - enable_relative_embedding=False, - dtype=jnp.bfloat16) + te_Encoder = partial( + 
te_flax.TransformerLayer, + hidden_size=256, + mlp_hidden_size=1024, + num_attention_heads=8, + hidden_dropout=0.1, + attention_dropout=0.1, + dropout_rng_name=DROPOUT_KEY, + layer_type=te_flax.TransformerLayerType.ENCODER, + self_attn_mask_type="padding", + enable_relative_embedding=False, + dtype=jnp.bfloat16, + ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) @@ -78,20 +81,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): def train_epoch(state, train_ds, batch_size, rngs, var_collect): """Train for a single epoch.""" - train_ds_size = len(train_ds['sentence']) + train_ds_size = len(train_ds["sentence"]) steps_per_epoch = train_ds_size // batch_size perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size) - perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch + perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch perms = perms.reshape((steps_per_epoch, batch_size)) epoch_loss = [] epoch_accuracy = [] for perm in perms: - batch_inputs = train_ds['sentence'][perm, ...] - batch_masks = train_ds['mask'][perm, ...] - batch_labels = train_ds['label'][perm, ...] - state, loss, accuracy, var_collect = train_step(state, batch_inputs, batch_masks, - batch_labels, var_collect, rngs) + batch_inputs = train_ds["sentence"][perm, ...] + batch_masks = train_ds["mask"][perm, ...] + batch_labels = train_ds["label"][perm, ...] + state, loss, accuracy, var_collect = train_step( + state, batch_inputs, batch_masks, batch_labels, var_collect, rngs + ) epoch_loss.append(loss) epoch_accuracy.append(accuracy) @@ -118,7 +122,7 @@ def eval_step(state, inputs, masks, labels, var_collect): def eval_model(state, test_ds, batch_size, var_collect): """Evaluation loop.""" - test_ds_size = len(test_ds['sentence']) + test_ds_size = len(test_ds["sentence"]) num_steps = test_ds_size // batch_size valid_size = num_steps * batch_size all_loss = [] @@ -126,9 +130,9 @@ def eval_model(state, test_ds, batch_size, var_collect): for batch_start in range(0, valid_size, batch_size): batch_end = batch_start + batch_size - batch_inputs = test_ds['sentence'][batch_start:batch_end] - batch_masks = test_ds['mask'][batch_start:batch_end] - batch_labels = test_ds['label'][batch_start:batch_end] + batch_inputs = test_ds["sentence"][batch_start:batch_end] + batch_masks = test_ds["mask"][batch_start:batch_end] + batch_labels = test_ds["label"][batch_start:batch_end] loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect) all_loss.append(loss) all_accuracy.append(accuracy) @@ -140,12 +144,12 @@ def eval_model(state, test_ds, batch_size, var_collect): def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download('punkt') - dataset_size = len(dataset['sentence']) + nltk.download("punkt") + dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) - for j, sentence in enumerate(dataset['sentence']): + for j, sentence in enumerate(dataset["sentence"]): tokens = nltk.word_tokenize(sentence) tensor = output[j] @@ -165,9 +169,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): mask_2d[:seq_len, :seq_len] = 0 new_dataset = { - 'sentence': output, - 'label': dataset['label'].astype(np.float32), - 'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) + "sentence": output, + "label": dataset["label"].astype(np.float32), + 
"mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)), } return new_dataset, vocab, word_id @@ -177,12 +181,12 @@ def get_datasets(max_seq_len): vocab = {} word_id = 0 - train_ds = load_dataset('glue', 'cola', split='train') - train_ds.set_format(type='np') + train_ds = load_dataset("glue", "cola", split="train") + train_ds.set_format(type="np") train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) - test_ds = load_dataset('glue', 'cola', split='validation') - test_ds.set_format(type='np') + test_ds = load_dataset("glue", "cola", split="validation") + test_ds.set_format(type="np") test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) return train_ds, test_ds, word_id @@ -191,7 +195,8 @@ def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} assert "fp8_" in str( - jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) + jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs) + ) def train_and_evaluate(args): @@ -214,9 +219,9 @@ def train_and_evaluate(args): masks = jnp.zeros(mask_shape, dtype=jnp.uint8) var_collect = encoder.init(init_rngs, inputs, masks) tx = optax.adamw(args.lr) - state = train_state.TrainState.create(apply_fn=encoder.apply, - params=var_collect[PARAMS_KEY], - tx=tx) + state = train_state.TrainState.create( + apply_fn=encoder.apply, params=var_collect[PARAMS_KEY], tx=tx + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -235,15 +240,18 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect) + state, train_ds, args.batch_size, rngs, var_collect + ) test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect) - print(f"Epoch: {epoch:>2} " - f"Train Loss: {train_loss:.6f} " - f"Train Accuracy: {train_accuracy:.6f} " - f"Test Loss: {test_loss:.6f} " - f"Test Accuracy: {test_accuracy:.6f} ") + print( + f"Epoch: {epoch:>2} " + f"Train Loss: {train_loss:.6f} " + f"Train Accuracy: {train_accuracy:.6f} " + f"Test Loss: {test_loss:.6f} " + f"Test Accuracy: {test_accuracy:.6f} " + ) return [train_loss, train_accuracy, test_loss, test_accuracy] @@ -293,10 +301,12 @@ def encoder_parser(args): help="quickly check a single pass", ) parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)") - parser.add_argument("--use-fp8", - action="store_true", - default=False, - help="Use FP8 for inference and training without recalibration") + parser.add_argument( + "--use-fp8", + action="store_true", + default=False, + help="Use FP8 for inference and training without recalibration", + ) return parser.parse_args(args) diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index ae74a66337205f400dcb372424766a331f5bfa3a..ff431a261b4495db2f99a0e3afd18c5eebd16db4 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -1,7 +1,7 @@ # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
-""" MNIST training on single GPU""" +"""MNIST training on single GPU""" import argparse import unittest from functools import partial @@ -20,13 +20,14 @@ import transformer_engine.jax.flax as te_flax IMAGE_H = 28 IMAGE_W = 28 IMAGE_C = 1 -PARAMS_KEY = 'params' -DROPOUT_KEY = 'dropout' -INPUT_KEY = 'input_rng' +PARAMS_KEY = "params" +DROPOUT_KEY = "dropout" +INPUT_KEY = "input_rng" class Net(nn.Module): """CNN model for MNIST.""" + use_te: bool = False @nn.compact @@ -83,17 +84,17 @@ def update_model(state, grads): def train_epoch(state, train_ds, batch_size, rngs, var_collect): """Train for a single epoch.""" - train_ds_size = len(train_ds['image']) + train_ds_size = len(train_ds["image"]) steps_per_epoch = train_ds_size // batch_size perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size) - perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch + perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch perms = perms.reshape((steps_per_epoch, batch_size)) epoch_loss = [] epoch_accuracy = [] for perm in perms: - batch_images = train_ds['image'][perm, ...] - batch_labels = train_ds['label'][perm, ...] + batch_images = train_ds["image"][perm, ...] + batch_labels = train_ds["label"][perm, ...] grads, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect, rngs) state, var_collect = update_model(state, grads) epoch_loss.append(loss) @@ -106,7 +107,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect): def eval_model(state, test_ds, batch_size, var_collect): """Evaluation loop.""" - test_ds_size = len(test_ds['image']) + test_ds_size = len(test_ds["image"]) num_steps = test_ds_size // batch_size valid_size = num_steps * batch_size all_loss = [] @@ -114,8 +115,8 @@ def eval_model(state, test_ds, batch_size, var_collect): for batch_start in range(0, valid_size, batch_size): batch_end = batch_start + batch_size - batch_images = test_ds['image'][batch_start:batch_end] - batch_labels = test_ds['label'][batch_start:batch_end] + batch_images = test_ds["image"][batch_start:batch_end] + batch_labels = test_ds["label"][batch_start:batch_end] _, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect) all_loss.append(loss) all_accuracy.append(accuracy) @@ -127,21 +128,21 @@ def eval_model(state, test_ds, batch_size, var_collect): def get_datasets(): """Load MNIST train and test datasets into memory.""" - train_ds = load_dataset('mnist', split='train') - train_ds.set_format(type='np') - batch_size = train_ds['image'].shape[0] + train_ds = load_dataset("mnist", split="train") + train_ds.set_format(type="np") + batch_size = train_ds["image"].shape[0] shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C) new_train_ds = { - 'image': train_ds['image'].astype(np.float32).reshape(shape) / 255., - 'label': train_ds['label'] + "image": train_ds["image"].astype(np.float32).reshape(shape) / 255.0, + "label": train_ds["label"], } - test_ds = load_dataset('mnist', split='test') - test_ds.set_format(type='np') - batch_size = test_ds['image'].shape[0] + test_ds = load_dataset("mnist", split="test") + test_ds.set_format(type="np") + batch_size = test_ds["image"].shape[0] shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C) new_test_ds = { - 'image': test_ds['image'].astype(np.float32).reshape(shape) / 255., - 'label': test_ds['label'] + "image": test_ds["image"].astype(np.float32).reshape(shape) / 255.0, + "label": test_ds["label"], } return new_train_ds, new_test_ds @@ -149,8 +150,13 @@ def get_datasets(): def check_fp8(state, var_collect, 
input_shape, label_shape): "Check if model includes FP8." assert "f8_" in str( - jax.make_jaxpr(apply_model)(state, jnp.empty(input_shape, dtype=jnp.bfloat16), - jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect)) + jax.make_jaxpr(apply_model)( + state, + jnp.empty(input_shape, dtype=jnp.bfloat16), + jnp.empty(label_shape, dtype=jnp.bfloat16), + var_collect, + ) + ) def train_and_evaluate(args): @@ -173,17 +179,21 @@ def train_and_evaluate(args): cnn = Net(args.use_te) var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16)) tx = optax.sgd(args.lr, args.momentum) - state = train_state.TrainState.create(apply_fn=cnn.apply, - params=var_collect[PARAMS_KEY], - tx=tx) + state = train_state.TrainState.create( + apply_fn=cnn.apply, params=var_collect[PARAMS_KEY], tx=tx + ) if args.use_fp8: check_fp8(state, var_collect, input_shape, label_shape) if args.dry_run: - apply_model(state, jnp.empty(input_shape, dtype=jnp.bfloat16), - jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect, - {DROPOUT_KEY: dropout_rng}) + apply_model( + state, + jnp.empty(input_shape, dtype=jnp.bfloat16), + jnp.empty(label_shape, dtype=jnp.bfloat16), + var_collect, + {DROPOUT_KEY: dropout_rng}, + ) print("PASSED") return None @@ -193,14 +203,17 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect) + state, train_ds, args.batch_size, rngs, var_collect + ) test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect) - print(f"Epoch: {epoch:>2} " - f"Train Loss: {train_loss:.6f} " - f"Train Accuracy: {train_accuracy:.6f} " - f"Test Loss: {test_loss:.6f} " - f"Test Accuracy: {test_accuracy:.6f} ") + print( + f"Epoch: {epoch:>2} " + f"Train Loss: {train_loss:.6f} " + f"Train Accuracy: {train_accuracy:.6f} " + f"Test Loss: {test_loss:.6f} " + f"Test Accuracy: {test_accuracy:.6f} " + ) return [train_loss, train_accuracy, test_loss, test_accuracy] @@ -250,15 +263,18 @@ def mnist_parser(args): help="quickly check a single pass", ) parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument("--use-fp8", - action="store_true", - default=False, - help="Use FP8 for inference and training without recalibration. " \ - "It also enables Transformer Engine implicitly.") - parser.add_argument("--use-te", - action="store_true", - default=False, - help="Use Transformer Engine") + parser.add_argument( + "--use-fp8", + action="store_true", + default=False, + help=( + "Use FP8 for inference and training without recalibration. " + "It also enables Transformer Engine implicitly." 
+ ), + ) + parser.add_argument( + "--use-te", action="store_true", default=False, help="Use Transformer Engine" + ) return parser.parse_args(args) diff --git a/examples/paddle/mnist/test_single_gpu_mnist.py b/examples/paddle/mnist/test_single_gpu_mnist.py index c1eacf39dac99c18bc38aebe53b878d5c87c378f..de5c9e9b6cca759dcdaaa07b7667bf3be0d54ef5 100644 --- a/examples/paddle/mnist/test_single_gpu_mnist.py +++ b/examples/paddle/mnist/test_single_gpu_mnist.py @@ -59,7 +59,9 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8): model.train() losses = [] for batch_id, (data, labels) in enumerate(train_loader): - with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager + with paddle.amp.auto_cast( + dtype="bfloat16", level="O2" + ): # pylint: disable=not-context-manager with te.fp8_autocast(enabled=use_fp8): outputs = model(data) loss = F.cross_entropy(outputs, labels) @@ -70,10 +72,12 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8): optimizer.clear_gradients() if batch_id % args.log_interval == 0: - print(f"Train Epoch: {epoch} " - f"[{batch_id * len(data)}/{len(train_loader.dataset)} " - f"({100. * batch_id / len(train_loader):.0f}%)]\t" - f"Loss: {loss.item():.6f}") + print( + f"Train Epoch: {epoch} " + f"[{batch_id * len(data)}/{len(train_loader.dataset)} " + f"({100. * batch_id / len(train_loader):.0f}%)]\t" + f"Loss: {loss.item():.6f}" + ) if args.dry_run: return loss.item() avg_loss = sum(losses) / len(losses) @@ -89,7 +93,9 @@ def evaluate(model, test_loader, epoch, use_fp8): with paddle.no_grad(): for data, labels in test_loader: - with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager + with paddle.amp.auto_cast( + dtype="bfloat16", level="O2" + ): # pylint: disable=not-context-manager with te.fp8_autocast(enabled=use_fp8): outputs = model(data) acc = metric.compute(outputs, labels) @@ -104,7 +110,9 @@ def calibrate(model, test_loader): with paddle.no_grad(): for data, _ in test_loader: - with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager + with paddle.amp.auto_cast( + dtype="bfloat16", level="O2" + ): # pylint: disable=not-context-manager with te.fp8_autocast(enabled=False, calibrating=True): _ = model(data) @@ -160,20 +168,27 @@ def mnist_parser(args): metavar="N", help="how many batches to wait before logging training status", ) - parser.add_argument("--use-fp8", - action="store_true", - default=False, - help="Use FP8 for inference and training without recalibration. " \ - "It also enables Transformer Engine implicitly.") - parser.add_argument("--use-fp8-infer", - action="store_true", - default=False, - help="Use FP8 for inference only. If not using FP8 for training, " - "calibration is performed for FP8 infernece.") - parser.add_argument("--use-te", - action="store_true", - default=False, - help="Use Transformer Engine") + parser.add_argument( + "--use-fp8", + action="store_true", + default=False, + help=( + "Use FP8 for inference and training without recalibration. " + "It also enables Transformer Engine implicitly." + ), + ) + parser.add_argument( + "--use-fp8-infer", + action="store_true", + default=False, + help=( + "Use FP8 for inference only. If not using FP8 for training, " + "calibration is performed for FP8 infernece." 
+ ), + ) + parser.add_argument( + "--use-te", action="store_true", default=False, help="Use Transformer Engine" + ) args = parser.parse_args(args) return args @@ -185,9 +200,9 @@ def train_and_evaluate(args): paddle.seed(args.seed) # Load MNIST dataset - transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW') - train_dataset = MNIST(mode='train', transform=transform) - val_dataset = MNIST(mode='test', transform=transform) + transform = Normalize(mean=[127.5], std=[127.5], data_format="CHW") + train_dataset = MNIST(mode="train", transform=transform) + val_dataset = MNIST(mode="test", transform=transform) # Define data loaders train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) @@ -198,7 +213,7 @@ def train_and_evaluate(args): optimizer = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters()) # Cast model to BF16 - model = paddle.amp.decorate(models=model, level='O2', dtype='bfloat16') + model = paddle.amp.decorate(models=model, level="O2", dtype="bfloat16") for epoch in range(1, args.epochs + 1): loss = train(args, model, train_loader, optimizer, epoch, args.use_fp8) @@ -209,7 +224,7 @@ def train_and_evaluate(args): if args.save_model or args.use_fp8_infer: paddle.save(model.state_dict(), "mnist_cnn.pdparams") - print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8)) + print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8)) weights = paddle.load("mnist_cnn.pdparams") model.set_state_dict(weights) acc = evaluate(model, val_loader, 0, args.use_fp8) @@ -235,8 +250,10 @@ class TestMNIST(unittest.TestCase): assert actual[0] < desired_traing_loss assert actual[1] > desired_test_accuracy - @unittest.skipIf(paddle.device.cuda.get_device_capability() < (8, 0), - "BF16 MNIST example requires Ampere+ GPU") + @unittest.skipIf( + paddle.device.cuda.get_device_capability() < (8, 0), + "BF16 MNIST example requires Ampere+ GPU", + ) def test_te_bf16(self): """Test Transformer Engine with BF16""" self.args.use_te = True diff --git a/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py b/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py index f0905d76975e9d4de6ffdb1faf4cb91018634b37..619dbaf9d736bbe185bd7a978c869fe145611874 100644 --- a/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py @@ -15,62 +15,77 @@ import torch.distributed as dist import transformer_engine.pytorch as te from transformer_engine.common.recipe import Format, DelayedScaling + def parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser( - description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers.") - parser.add_argument('-i', "--num-iters", type=int, default=5, - help="Number of dummy 'training' iterations.") - parser.add_argument('-b', "--batch-size", type=int, default=2, - help="Input batch size.") - parser.add_argument('-s', "--seq-length", type=int, default=2048, - help="Input sequence length.") - parser.add_argument('-n', "--num-heads", type=int, default=64, - help="Number of attention heads.") - parser.add_argument('-d', "--head-dim", type=int, default=128, - help="Dimension of each attention head.") - parser.add_argument("--mlp-expansion-factor", type=int, default=4, - help="MLP block intermediate size as a factor of hidden dimension.") - parser.add_argument("--seed", type=int, default=1234, - help="RNG seed.") - parser.add_argument("--fp8", action="store_true", default=False, - help="Enables the te.fp8_autocast() context.") - 
parser.add_argument("--no-comm-overlap", action="store_true", default=False, - help="Disable the comm+GEMM overlap.") - parser.add_argument('-v', "--verbose", action="store_true", default=False) + description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers." + ) + parser.add_argument( + "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations." + ) + parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") + parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.") + parser.add_argument( + "-n", "--num-heads", type=int, default=64, help="Number of attention heads." + ) + parser.add_argument( + "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." + ) + parser.add_argument( + "--mlp-expansion-factor", + type=int, + default=4, + help="MLP block intermediate size as a factor of hidden dimension.", + ) + parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") + parser.add_argument( + "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + ) + parser.add_argument( + "--no-comm-overlap", + action="store_true", + default=False, + help="Disable the comm+GEMM overlap.", + ) + parser.add_argument("-v", "--verbose", action="store_true", default=False) return parser.parse_args(argv, namespace) + def train(opts): WORLD_RANK = int(os.getenv("RANK")) WORLD_SIZE = int(os.getenv("WORLD_SIZE")) - def dist_print(msg, end='\n', all_ranks=False): + + def dist_print(msg, end="\n", all_ranks=False): if WORLD_RANK == 0 or all_ranks: print(f"[RANK-{WORLD_RANK}] {msg}", end=end) # Seed RNG torch.cuda.set_device(WORLD_RANK) - torch.manual_seed(opts.seed+WORLD_RANK) - torch.cuda.manual_seed(opts.seed+WORLD_RANK) + torch.manual_seed(opts.seed + WORLD_RANK) + torch.cuda.manual_seed(opts.seed + WORLD_RANK) # Initialize torch.distributed global process group and get TP group - dist.init_process_group(backend="nccl", - rank=WORLD_RANK, - world_size=WORLD_SIZE, - device_id=torch.device(f'cuda:{WORLD_RANK}')) + dist.init_process_group( + backend="nccl", + rank=WORLD_RANK, + world_size=WORLD_SIZE, + device_id=torch.device(f"cuda:{WORLD_RANK}"), + ) tp_group = dist.new_group(backend="nccl") tp_size = dist.get_world_size(tp_group) # Intialize userbuffers ag_cfg = { # Ring-exchange All-Gather overlap for fc1_fprop and fc2_dgrad - 'method': 'ring_exchange', - 'num_splits' : 8, - 'num_sm' : 1, - 'set_sm_margin' : False, + "method": "ring_exchange", + "num_splits": 8, + "num_sm": 1, + "set_sm_margin": False, } rs_cfg = { # Reduce-scatter overlap for fc1_dgrad and fc2_fprop - 'method': 'ring_exchange', - 'num_splits' : 4, - 'num_sm' : 1, - 'set_sm_margin' : True, + "method": "ring_exchange", + "num_splits": 4, + "num_sm": 1, + "set_sm_margin": True, } hidden_size = opts.num_heads * opts.head_dim batched_size = opts.seq_length * opts.batch_size @@ -78,30 +93,31 @@ def train(opts): te.initialize_ub( [batched_size, hidden_size], tp_group, - use_fp8 = opts.fp8, - dtype = torch.bfloat16, - ub_cfgs = { - 'fc1_fprop': ag_cfg, - 'fc1_dgrad': rs_cfg, - 'fc2_fprop': rs_cfg, - 'fc2_dgrad': ag_cfg, + use_fp8=opts.fp8, + dtype=torch.bfloat16, + ub_cfgs={ + "fc1_fprop": ag_cfg, + "fc1_dgrad": rs_cfg, + "fc2_fprop": rs_cfg, + "fc2_dgrad": ag_cfg, }, ) # model = te.LayerNormMLP( - hidden_size, opts.mlp_expansion_factor * hidden_size, - params_dtype = torch.bfloat16, - device = 'cuda', - tp_group = tp_group, - tp_size = tp_size, - 
set_parallel_mode = True, - sequence_parallel = True, # this is required for comm+GEMM overlap - seq_length = opts.seq_length, - micro_batch_size = opts.batch_size, - ub_overlap_rs_dgrad = not opts.no_comm_overlap, - ub_overlap_rs = not opts.no_comm_overlap, - ub_overlap_ag = not opts.no_comm_overlap, + hidden_size, + opts.mlp_expansion_factor * hidden_size, + params_dtype=torch.bfloat16, + device="cuda", + tp_group=tp_group, + tp_size=tp_size, + set_parallel_mode=True, + sequence_parallel=True, # this is required for comm+GEMM overlap + seq_length=opts.seq_length, + micro_batch_size=opts.batch_size, + ub_overlap_rs_dgrad=not opts.no_comm_overlap, + ub_overlap_rs=not opts.no_comm_overlap, + ub_overlap_ag=not opts.no_comm_overlap, ) # Initialize optimizer with model parameters @@ -109,16 +125,19 @@ def train(opts): # Fp8 recipe setup fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, - amax_compute_algo="max") + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") # Start dummy "training" iterations for i in range(opts.num_iters): dist_print(f"Iter {i+1}", all_ranks=opts.verbose) dist_print("|-- Generate random input batch", all_ranks=opts.verbose) - x = torch.rand((opts.seq_length // tp_size, opts.batch_size, hidden_size), - dtype=torch.bfloat16, device='cuda', requires_grad=True) + x = torch.rand( + (opts.seq_length // tp_size, opts.batch_size, hidden_size), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) dist_print("|-- Forward pass", all_ranks=opts.verbose) with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group): @@ -135,17 +154,15 @@ def train(opts): te.destroy_ub() dist.destroy_process_group() + if __name__ == "__main__": if "TORCHELASTIC_RUN_ID" in os.environ.keys(): args = parse_args() train(args) else: subprocess.run( - [ - 'torchrun', f'--nproc-per-node={torch.cuda.device_count()}', - *sys.argv - ], + ["torchrun", f"--nproc-per-node={torch.cuda.device_count()}", *sys.argv], env=os.environ, - check=True + check=True, ) os._exit(0) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index e78c1b1582c95d59b866d469f7cf0c80dabff762..cf0a75c3365e7fb9aab693adda85bbb015de667a 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -14,7 +14,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing, - checkpoint_wrapper + checkpoint_wrapper, ) import transformer_engine.pytorch as te @@ -29,46 +29,56 @@ rng_seed = 1234 torch.manual_seed(rng_seed) torch.cuda.manual_seed(rng_seed) CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker() -CUDA_RNG_STATES_TRACKER.add('model-parallel-rng', rng_seed) +CUDA_RNG_STATES_TRACKER.add("model-parallel-rng", rng_seed) + + def get_cuda_rng_tracker(): return CUDA_RNG_STATES_TRACKER + def apply_fsdp_checkpointing(model, blocks): """apply activation checkpointing to model returns None as model is updated directly """ - wrapper = lambda m: checkpoint_wrapper(m, - checkpoint_fn=te.distributed.checkpoint, - use_reentrant=False, - get_rng_state_tracker=get_cuda_rng_tracker) + wrapper = lambda m: checkpoint_wrapper( + m, + checkpoint_fn=te.distributed.checkpoint, + use_reentrant=False, + get_rng_state_tracker=get_cuda_rng_tracker, + ) check_fn = lambda 
submodule: isinstance(submodule, blocks) apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn) + def lowercase(s): return str(s).lower() + def torch_dtype(d): typemap = { - 'fp32' : torch.float32, - 'float32' : torch.float32, - 'fp16' : torch.float16, - 'float16' : torch.float16, - 'bf16' : torch.bfloat16, - 'bfloat16' : torch.bfloat16 + "fp32": torch.float32, + "float32": torch.float32, + "fp16": torch.float16, + "float16": torch.float16, + "bf16": torch.bfloat16, + "bfloat16": torch.bfloat16, } if lowercase(d) not in typemap.keys(): raise TypeError return typemap[lowercase(d)] + te_layer_map = { - 'linear': te.Linear, - 'layernorm': te.LayerNorm, - 'rmsnorm': te.RMSNorm, - 'layernormlinear': te.LayerNormLinear, - 'layernormmlp': te.LayerNormMLP, - 'multiheadattention': te.MultiheadAttention, - 'transformerlayer': te.TransformerLayer + "linear": te.Linear, + "layernorm": te.LayerNorm, + "rmsnorm": te.RMSNorm, + "layernormlinear": te.LayerNormLinear, + "layernormmlp": te.LayerNormMLP, + "multiheadattention": te.MultiheadAttention, + "transformerlayer": te.TransformerLayer, } + + def te_layer(l): if l is not None: if lowercase(l) not in te_layer_map.keys(): @@ -76,74 +86,120 @@ def te_layer(l): return te_layer_map[lowercase(l)] return None + def get_layer_args(opts): hidden_size = opts.num_heads * opts.head_dim - layer_args = (hidden_size, ) + layer_args = (hidden_size,) layer_kwargs = { - 'params_dtype': opts.dtype, - 'device': 'cuda' if opts.no_defer_init else 'meta', - 'get_rng_state_tracker': get_cuda_rng_tracker, + "params_dtype": opts.dtype, + "device": "cuda" if opts.no_defer_init else "meta", + "get_rng_state_tracker": get_cuda_rng_tracker, } if opts.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]: ffn_hidden_size = 3 * hidden_size if opts.num_layers == 1 else hidden_size - layer_args += (ffn_hidden_size, ) - layer_kwargs['bias'] = True + layer_args += (ffn_hidden_size,) + layer_kwargs["bias"] = True if opts.layer_type == te.LayerNormMLP: - layer_kwargs['seq_length'] = opts.seq_length + layer_kwargs["seq_length"] = opts.seq_length elif opts.layer_type == te.MultiheadAttention: - layer_args += (opts.num_heads, ) - layer_kwargs['fuse_qkv_params'] = True - layer_kwargs['input_layernorm'] = True + layer_args += (opts.num_heads,) + layer_kwargs["fuse_qkv_params"] = True + layer_kwargs["input_layernorm"] = True elif opts.layer_type == te.TransformerLayer: layer_args += (3 * hidden_size, opts.num_heads) - layer_kwargs['fuse_qkv_params'] = True - layer_kwargs['seq_length'] = opts.seq_length + layer_kwargs["fuse_qkv_params"] = True + layer_kwargs["seq_length"] = opts.seq_length return layer_args, layer_kwargs + def parse_fsdp_args(): - parser = argparse.ArgumentParser(description="Run Transformer Engine modules with the " + - "torch.distributed.fsdp.FullyShardedDataParallel strategy.") - parser.add_argument('-v', "--verbose", action="store_true", default=False, - help="Print out information from all GPUs instead of only the root GPU-0.") - parser.add_argument('-b', "--batch-size", type=int, default=32, - help="Input batch size.") - parser.add_argument('-s', "--seq-length", type=int, default=1048, - help="Input sequence length.") - parser.add_argument('-n', "--num-heads", type=int, default=16, - help="Number of attention heads.") - parser.add_argument('-d', "--head-dim", type=int, default=128, - help="Dimension of each attention head (number of KV channels).") - parser.add_argument('-i', "--num-iters", type=int, default=5, - help="Number of 
dummy 'training' iterations.") - parser.add_argument('-k', "--num-layers", type=int, default=3, - help="Number of modules chained together with nn.Sequential.") - parser.add_argument("--layer-type", type=te_layer, default=te.TransformerLayer, - choices=list(te_layer_map.values()), - help="TE module type used to construct the test model.") - parser.add_argument("--seed", type=int, default=1234, - help="PyTorch RNG seed.") - parser.add_argument("--profile-memory", action="store_true", - help="Enable memory profiling via torch.profiler.profile().") - parser.add_argument("--profile-name", type=str, default=None, - help="File path for memory profiling.") - parser.add_argument("--checkpoint-layer", type=te_layer, default=None, - help="Recompute activations of the selected layer during the backward " + \ - "pass instead of saving.") - parser.add_argument("--no-fp8", action="store_true", default=False, - help="Disables the te.fp8_autocast() context.") - parser.add_argument("--no-defer-init", action="store_true", - help="Defer module parameter initialization until after FSDP sharding.") - parser.add_argument("--no-te-fsdp", action="store_true", - help="Disable sharding of intermediate/activation tensors in TE modules.") - parser.add_argument("--dtype", type=torch_dtype, default=torch.bfloat16, - help="Data type for input tensor and Transformer Engine module parameters.") + parser = argparse.ArgumentParser( + description="Run Transformer Engine modules with the " + + "torch.distributed.fsdp.FullyShardedDataParallel strategy." + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="Print out information from all GPUs instead of only the root GPU-0.", + ) + parser.add_argument("-b", "--batch-size", type=int, default=32, help="Input batch size.") + parser.add_argument("-s", "--seq-length", type=int, default=1048, help="Input sequence length.") + parser.add_argument( + "-n", "--num-heads", type=int, default=16, help="Number of attention heads." + ) + parser.add_argument( + "-d", + "--head-dim", + type=int, + default=128, + help="Dimension of each attention head (number of KV channels).", + ) + parser.add_argument( + "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations." + ) + parser.add_argument( + "-k", + "--num-layers", + type=int, + default=3, + help="Number of modules chained together with nn.Sequential.", + ) + parser.add_argument( + "--layer-type", + type=te_layer, + default=te.TransformerLayer, + choices=list(te_layer_map.values()), + help="TE module type used to construct the test model.", + ) + parser.add_argument("--seed", type=int, default=1234, help="PyTorch RNG seed.") + parser.add_argument( + "--profile-memory", + action="store_true", + help="Enable memory profiling via torch.profiler.profile().", + ) + parser.add_argument( + "--profile-name", type=str, default=None, help="File path for memory profiling." 
+ ) + parser.add_argument( + "--checkpoint-layer", + type=te_layer, + default=None, + help="Recompute activations of the selected layer during the backward " + + "pass instead of saving.", + ) + parser.add_argument( + "--no-fp8", + action="store_true", + default=False, + help="Disables the te.fp8_autocast() context.", + ) + parser.add_argument( + "--no-defer-init", + action="store_true", + help="Defer module parameter initialization until after FSDP sharding.", + ) + parser.add_argument( + "--no-te-fsdp", + action="store_true", + help="Disable sharding of intermediate/activation tensors in TE modules.", + ) + parser.add_argument( + "--dtype", + type=torch_dtype, + default=torch.bfloat16, + help="Data type for input tensor and Transformer Engine module parameters.", + ) return parser.parse_args() + def dist_print(text, all_ranks=False, no_new_line=False): if LOCAL_RANK == 0 or all_ranks: - end = '' if no_new_line else '\n' + end = "" if no_new_line else "\n" print(f"[GPU-{LOCAL_RANK}] " + text, end=end) + def train(opts): # Initialize torch.distributed global process group dist.init_process_group(backend="nccl") @@ -157,7 +213,7 @@ def train(opts): te_layer_list = [] for i in range(opts.num_layers): if opts.layer_type in [te.MultiheadAttention, te.TransformerLayer]: - layer_kwargs['layer_number'] = i+1 + layer_kwargs["layer_number"] = i + 1 te_layer_list.append(opts.layer_type(*layer_args, **layer_kwargs)) te_model = nn.Sequential(*te_layer_list) else: @@ -171,20 +227,23 @@ def train(opts): # Wrap the model with FSDP # NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and # controls all communication. - all_gpus = dist.new_group(backend='nccl') + all_gpus = dist.new_group(backend="nccl") fsdp_wrap_policy = always_wrap_policy if opts.layer_type == te.TransformerLayer: # NOTE: FSDP causes illegal memory access without this special policy for Transformers - fsdp_wrap_policy = partial(transformer_auto_wrap_policy, - transformer_layer_cls={te.TransformerLayer}) - te_model = FullyShardedDataParallel(te_model, - process_group=all_gpus, - use_orig_params=True, - mixed_precision=MixedPrecision( - param_dtype=opts.dtype, - reduce_dtype=torch.float32, - ), - auto_wrap_policy=fsdp_wrap_policy) + fsdp_wrap_policy = partial( + transformer_auto_wrap_policy, transformer_layer_cls={te.TransformerLayer} + ) + te_model = FullyShardedDataParallel( + te_model, + process_group=all_gpus, + use_orig_params=True, + mixed_precision=MixedPrecision( + param_dtype=opts.dtype, + reduce_dtype=torch.float32, + ), + auto_wrap_policy=fsdp_wrap_policy, + ) if opts.checkpoint_layer is not None: # Recompute the activations of the selected layer during the backward pass instead of @@ -218,8 +277,13 @@ def train(opts): for i in range(opts.num_iters): # Generate a random input batch - x = torch.rand(opts.seq_length, opts.batch_size, opts.num_heads*opts.head_dim, - dtype=opts.dtype, device='cuda') + x = torch.rand( + opts.seq_length, + opts.batch_size, + opts.num_heads * opts.head_dim, + dtype=opts.dtype, + device="cuda", + ) # fp8_autocast needs to be given the FSDP process group for amax reductions with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus): y = te_model(x) @@ -230,7 +294,6 @@ def train(opts): optim.zero_grad(set_to_none=True) del x - if opts.profile_memory: torch.cuda.memory._dump_snapshot(f"gpu{LOCAL_RANK}_{opts.profile_name}.pickle") torch.cuda.memory._record_memory_history(enabled=None) @@ -238,7 +301,7 @@ def train(opts): end.record() 
torch.cuda.synchronize() peak_mem = torch.cuda.max_memory_allocated() - train_time = start.elapsed_time(end)/1000. + train_time = start.elapsed_time(end) / 1000.0 dist_print(f"Training Time: {train_time}s") dist_print(f"Avg. Iter. Time: {train_time / opts.num_iters}s") dist_print(f"Peak Memory Use: {peak_mem * 1e-6}MBs") diff --git a/examples/pytorch/mnist/main.py b/examples/pytorch/mnist/main.py index d8bfcfd4808791f9ff35af311b27162a6f0605d9..2a003f0a0d464b961c15f5a63ef0d07624e50126 100644 --- a/examples/pytorch/mnist/main.py +++ b/examples/pytorch/mnist/main.py @@ -79,6 +79,7 @@ def calibrate(model, device, test_loader, fp8): with te.fp8_autocast(enabled=fp8, calibrating=True): output = model(data) + def test(model, device, test_loader, use_fp8): """Testing function.""" model.eval() @@ -89,12 +90,8 @@ def test(model, device, test_loader, use_fp8): data, target = data.to(device), target.to(device) with te.fp8_autocast(enabled=use_fp8): output = model(data) - test_loss += F.nll_loss( - output, target, reduction="sum" - ).item() # sum up batch loss - pred = output.argmax( - dim=1, keepdim=True - ) # get the index of the max log-probability + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) @@ -150,9 +147,7 @@ def main(): default=False, help="quickly check a single pass", ) - parser.add_argument( - "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" - ) + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") parser.add_argument( "--log-interval", type=int, @@ -167,7 +162,10 @@ def main(): help="For Saving the current Model", ) parser.add_argument( - "--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration" + "--use-fp8", + action="store_true", + default=False, + help="Use FP8 for inference and training without recalibration", ) parser.add_argument( "--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only" @@ -215,7 +213,7 @@ def main(): if args.save_model or args.use_fp8_infer: torch.save(model.state_dict(), "mnist_cnn.pt") - print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer)) + print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer)) weights = torch.load("mnist_cnn.pt") model.load_state_dict(weights) test(model, device, test_loader, args.use_fp8_infer) diff --git a/qa/L0_license/copyright_checker.py b/qa/L0_license/copyright_checker.py index 310610fac6f6ae1197c32c6e821fa14428e3125e..46a3a6d4fe9005005df43d74c5502d470a885261 100644 --- a/qa/L0_license/copyright_checker.py +++ b/qa/L0_license/copyright_checker.py @@ -18,21 +18,26 @@ path = sys.argv[1] config_path = os.path.dirname(os.path.realpath(__file__)) + "/config.json" + class bcolors: - OKGREEN = '\033[92m' - WARNING = '\033[93m' - FAIL = '\033[91m' - ENDC = '\033[0m' + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + def print_ok(msg): print(f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}") + def print_fail(msg): print(f"{bcolors.FAIL}{msg}{bcolors.ENDC}") + def print_warn(msg): print(f"{bcolors.WARNING}{msg}{bcolors.ENDC}") + with open(config_path, "r") as f: c = json.load(f) current_year = datetime.date.today().year @@ -41,7 +46,7 @@ with open(config_path, "r") as f: else: year_string = str(c["initial_year"]) + "-" + 
str(current_year) copyright_string = c["copyright"].replace("", year_string) - license = c["license"].split('\n') + license = c["license"].split("\n") excludes = c["exclude"] root_path = os.path.abspath(path) copyright_only = c["copyright_only"] @@ -49,36 +54,42 @@ with open(config_path, "r") as f: has_gitignore = os.path.exists(root_path + "/.gitignore") + def strip_star_slash(s): ret = s - if ret.startswith('*'): + if ret.startswith("*"): ret = ret[1:] - if ret.endswith('/'): + if ret.endswith("/"): ret = ret[:-1] return ret + if has_gitignore: with open(root_path + "/.gitignore", "r") as f: for line in f.readlines(): excludes.append(strip_star_slash(line.strip())) + def get_file_type(path): - ext = {"c": ["c", "cpp", "cu", "h", "cuh"], - "py": ["py"], - "rst": ["rst"], - "txt": ["txt"], - "cfg": ["cfg"], - "sh": ["sh"], - "md": ["md"], - } + ext = { + "c": ["c", "cpp", "cu", "h", "cuh"], + "py": ["py"], + "rst": ["rst"], + "txt": ["txt"], + "cfg": ["cfg"], + "sh": ["sh"], + "md": ["md"], + } tmp = path.split(".") for filetype, ext_list in ext.items(): if tmp[-1] in ext_list: return filetype return "unknown" + success = True + def check_file(path): global success N = 10 @@ -127,9 +138,10 @@ def check_file(path): if copyright_found and license_found: print_ok("OK") + for root, dirs, files in os.walk(root_path): print(f"Entering {root}") - hidden = [d for d in dirs if d.startswith('.')] + [f for f in files if f.startswith('.')] + hidden = [d for d in dirs if d.startswith(".")] + [f for f in files if f.startswith(".")] all_excludes = excludes + hidden to_remove = [] for d in dirs: diff --git a/setup.py b/setup.py index d6493447633ca243e1de4be0b4a6e351d734839d..d2cc91d65a16357dd0bd54eb6467d5cad2eeaa2b 100644 --- a/setup.py +++ b/setup.py @@ -27,12 +27,13 @@ current_file_path = Path(__file__).parent.resolve() from setuptools.command.build_ext import build_ext as BuildExtension + if "pytorch" in frameworks: from torch.utils.cpp_extension import BuildExtension elif "paddle" in frameworks: from paddle.utils.cpp_extension import BuildExtension elif "jax" in frameworks: - install_and_import('pybind11') + install_and_import("pybind11") from pybind11.setup_helpers import build_ext as BuildExtension @@ -86,34 +87,45 @@ if __name__ == "__main__": if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: from build_tools.pytorch import setup_pytorch_extension + ext_modules.append( setup_pytorch_extension( "transformer_engine/pytorch/csrc", current_file_path / "transformer_engine" / "pytorch" / "csrc", - current_file_path / "transformer_engine")) + current_file_path / "transformer_engine", + ) + ) if "jax" in frameworks: from build_tools.jax import setup_jax_extension + ext_modules.append( setup_jax_extension( "transformer_engine/jax/csrc", current_file_path / "transformer_engine" / "jax" / "csrc", - current_file_path / "transformer_engine")) + current_file_path / "transformer_engine", + ) + ) if "paddle" in frameworks: from build_tools.paddle import setup_paddle_extension + ext_modules.append( setup_paddle_extension( "transformer_engine/paddle/csrc", current_file_path / "transformer_engine" / "paddle" / "csrc", - current_file_path / "transformer_engine")) + current_file_path / "transformer_engine", + ) + ) # Configure package setuptools.setup( name="transformer_engine", version=__version__, packages=setuptools.find_packages( - include=["transformer_engine", - "transformer_engine.*", - "transformer_engine/build_tools"], + include=[ + "transformer_engine", + 
"transformer_engine.*", + "transformer_engine/build_tools", + ], ), extras_require={ "test": test_requires, @@ -125,5 +137,5 @@ if __name__ == "__main__": install_requires=install_requires, license_files=("LICENSE",), include_package_data=True, - package_data={"": ["VERSION.txt"]} + package_data={"": ["VERSION.txt"]}, ) diff --git a/tests/cpp/operator/test_causal_softmax.cu b/tests/cpp/operator/test_causal_softmax.cu index 1d7d70e5722b527c46e938efe27a9aac23d0fb93..640434674b40e6aa48d27cc27a5ba59611372699 100644 --- a/tests/cpp/operator/test_causal_softmax.cu +++ b/tests/cpp/operator/test_causal_softmax.cu @@ -132,7 +132,7 @@ void compute_bwd_ref( for (int b = 0; b < batches; ++b) { for (int h = 0; h < heads; ++h) { size_t offset = b * batch_size + h * head_size; - compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset, + compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset, buff + offset, scaling_factor, batches, heads, rows, cols); } } diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 5f1aaa4c397751c4b1909c90dd3c272d2472ce00..55494c42d631245c1aedc532bbc23024b2326319 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -6,7 +6,7 @@ import jax import pytest -@pytest.fixture(autouse=True, scope='function') +@pytest.fixture(autouse=True, scope="function") def clear_live_arrays(): """ Clear all live arrays to keep the resource clean diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 1360af19fadc955664f05210f3b41eddd98ffd3b..3a7fe333785a77c3e6dcf016bac331e9f0424d32 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -16,15 +16,15 @@ from utils import assert_allclose, is_devices_enough def generate_configs(): configs = [] if is_devices_enough(2): - configs.append([2, (2,), ('dp'), MeshResource(dp_resource='dp')]) - configs.append([2, (2,), ('tp'), MeshResource(tp_resource='tp')]) - + configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")]) + configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")]) + if is_devices_enough(4): TP_size = 2 DP_size = 2 configs.append( - [4, (DP_size, TP_size), ('dp', 'tp'), - MeshResource(dp_resource='dp', tp_resource='tp')]) + [4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")] + ) return configs @@ -46,7 +46,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref): bytes_count = 0 def get_bytes_per_txt(t): - ''' + """ The pattern of t would be like: 'f32[]', '(f32[1024]{0}', @@ -54,24 +54,24 @@ def assert_equal_collectives(target_hlo, coll_count_ref): 'f8E4M3FN[1024]{0}', 'i32[1024]{0}', 'bf16[1024,1024]{0}' - ''' - match = re.search(r'(i|f)(\d+).*\[([0-9,]*)\]', t) + """ + match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t) _, bits_of_type, shape = match.groups() bytes_of_type = int(bits_of_type) // 8 - if shape == '': + if shape == "": num_of_elements = 1 else: - num_of_elements = reduce(operator.mul, map(int, shape.split(','))) + num_of_elements = reduce(operator.mul, map(int, shape.split(","))) return bytes_of_type * num_of_elements # ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...] - if '(' in hlo_text[2]: + if "(" in hlo_text[2]: for txt in hlo_text[2:]: bytes_count += get_bytes_per_txt(txt) - if ')' in txt: + if ")" in txt: break - else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...] + else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...] 
bytes_count = get_bytes_per_txt(hlo_text[2]) return bytes_count @@ -91,21 +91,24 @@ def assert_equal_collectives(target_hlo, coll_count_ref): return result target_result = count_collectives(target_splitted_hlo) - assert target_result == coll_count_ref, \ - f"Expected collective count is {coll_count_ref}, but got {target_result}." - - -def compare_ops(target_func, - ref_func, - inputs, - coll_count_ref, - *, - grad_args=None, - metric_fwd_dtype=None, - metric_bwd_dtype=None, - in_shardings=_UNSPECIFIED, - out_shardings=_UNSPECIFIED, - **kwargs): + assert ( + target_result == coll_count_ref + ), f"Expected collective count is {coll_count_ref}, but got {target_result}." + + +def compare_ops( + target_func, + ref_func, + inputs, + coll_count_ref, + *, + grad_args=None, + metric_fwd_dtype=None, + metric_bwd_dtype=None, + in_shardings=_UNSPECIFIED, + out_shardings=_UNSPECIFIED, + **kwargs, +): assert len(inputs) >= 1 if metric_fwd_dtype is None: diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 1e0b424e2fc92b5cfe9e6a454329ce4f27d285cb..8664a03f8d6af2b0c92f7e7ac9ff420e11a8526a 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -15,24 +15,10 @@ from jax import jit, value_and_grad from flax import linen as nn from utils import assert_allclose -from transformer_engine.jax.dot import ( - type_safe_dot_general, - dequantize, - quantize -) -from transformer_engine.jax.fp8 import ( - FP8MetaPackage, - FP8Helper, - is_fp8_available -) -from transformer_engine.jax.layernorm import ( - layernorm, - layernorm_fp8_dot -) -from transformer_engine.jax.layernorm_mlp import ( - activation_lu, - fused_layernorm_fp8_mlp -) +from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize +from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available +from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot +from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp from transformer_engine.jax import cpp_extensions as tex GEMM_CASES = [ @@ -50,11 +36,11 @@ is_fp8_supported, reason = is_fp8_available() def _convert_to_activation_function(fn_or_string): """Convert a string to an activation function.""" - if fn_or_string == 'linear': + if fn_or_string == "linear": return lambda x: x - if fn_or_string == 'quick_gelu': + if fn_or_string == "quick_gelu": return lambda x: nn.gelu(x, approximate=True) - if fn_or_string == 'squared_relu': + if fn_or_string == "squared_relu": return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)]) if isinstance(fn_or_string, str): return getattr(nn, fn_or_string) @@ -93,7 +79,7 @@ class TestFP8Dot: assert_allclose(z, x, dtype=jnp.float8_e4m3fn) - @pytest.mark.parametrize('m,n,k', GEMM_CASES) + @pytest.mark.parametrize("m,n,k", GEMM_CASES) def test_forward_bf16(self, m, n, k): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) @@ -106,7 +92,7 @@ class TestFP8Dot: assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('m,n,k', GEMM_CASES) + @pytest.mark.parametrize("m,n,k", GEMM_CASES) def test_forward_fp8_randint(self, m, n, k): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) @@ -135,7 +121,7 @@ class TestFP8Dot: assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) - @pytest.mark.parametrize('m,n,k', GEMM_CASES) + @pytest.mark.parametrize("m,n,k", GEMM_CASES) def 
test_grad_bf16(self, m, n, k): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) @@ -161,7 +147,7 @@ class TestFP8Dot: assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('m,n,k', GEMM_CASES) + @pytest.mark.parametrize("m,n,k", GEMM_CASES) def test_grad_fp8_dot(self, m, n, k): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) @@ -192,33 +178,38 @@ class TestFP8Dot: ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b) for _ in range(3): - primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, - scale_list) = value_n_grad_primitive_func(a, b, amax_list, scale_list) + primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, scale_list) = ( + value_n_grad_primitive_func(a, b, amax_list, scale_list) + ) assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE) assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('m,n,k', [(256, 128, 512), - (16384, 1024, 2816), - (16384, 2816, 1024), - (16384, 1024, 1024)]) - @pytest.mark.parametrize('activation_type', [('gelu', ), - ('gelu', 'linear'), - ('silu', ), - ('silu', 'linear'), - ('relu',), - ('relu', 'linear'), - ('quick_gelu',), - ('quick_gelu', 'linear'), - ('squared_relu',), - ('squared_relu', 'linear')]) - @pytest.mark.parametrize('use_bias', [True, False]) - def test_grad_fused_layernorm_fp8_mlp(self, m, n, k, activation_type: Sequence[Union[str, - Callable]], - use_bias: bool): - """ N/a """ + @pytest.mark.parametrize( + "m,n,k", [(256, 128, 512), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)] + ) + @pytest.mark.parametrize( + "activation_type", + [ + ("gelu",), + ("gelu", "linear"), + ("silu",), + ("silu", "linear"), + ("relu",), + ("relu", "linear"), + ("quick_gelu",), + ("quick_gelu", "linear"), + ("squared_relu",), + ("squared_relu", "linear"), + ], + ) + @pytest.mark.parametrize("use_bias", [True, False]) + def test_grad_fused_layernorm_fp8_mlp( + self, m, n, k, activation_type: Sequence[Union[str, Callable]], use_bias: bool + ): + """N/a""" key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 6) @@ -233,8 +224,9 @@ class TestFP8Dot: b1 = None b2 = None - def primitive_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, - scale_list_2): + def primitive_func( + x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2 + ): # x is input tensor, matrix 2d # y, z are weights, matrix 2d # out = ((x * y) + w) * z + v @@ -255,18 +247,31 @@ class TestFP8Dot: scale_list_2[2], ) return jnp.mean( - fused_layernorm_fp8_mlp(x, - ln_s, - None, [y, z], [w, v], [fp8_meta_pkg_1, fp8_meta_pkg_2], - "rmsnorm", - activation_type=activation_type, - use_bias=use_bias)) - - def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray, - kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray, - amax_list_1: List[jnp.ndarray], amax_list_2: List[jnp.ndarray], - scale_list_1: List[jnp.ndarray], - scale_list_2: List[jnp.ndarray]) -> jnp.ndarray: + fused_layernorm_fp8_mlp( + x, + ln_s, + None, + [y, z], + [w, v], + [fp8_meta_pkg_1, fp8_meta_pkg_2], + "rmsnorm", + activation_type=activation_type, + use_bias=use_bias, + ) + ) + + def layernorm_fp8_mlp_ref( + x: jnp.ndarray, + ln_scale: jnp.ndarray, + kernel_1: jnp.ndarray, + kernel_2: jnp.ndarray, + 
bias_1: jnp.ndarray, + bias_2: jnp.ndarray, + amax_list_1: List[jnp.ndarray], + amax_list_2: List[jnp.ndarray], + scale_list_1: List[jnp.ndarray], + scale_list_2: List[jnp.ndarray], + ) -> jnp.ndarray: x = jnp.asarray(x, jnp.float32) mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) @@ -315,11 +320,14 @@ class TestFP8Dot: def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2): return jnp.mean( - layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, - scale_list_2)) + layernorm_fp8_mlp_ref( + x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2 + ) + ) value_n_grad_primitive_func = jit( - value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))) + value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)) + ) value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))) _, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta() @@ -339,40 +347,87 @@ class TestFP8Dot: # Convert str to index as str is not a valid type for JAX JIT for _ in range(3): - ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad, - ref_amax_list_1, ref_amax_list_2, ref_scale_list_1, - ref_scale_list_2) = value_n_grad_ref_func(a, s, k1, k2, b1, b2, - ref_amax_list_1, ref_amax_list_2, - ref_scale_list_1, ref_scale_list_2) + ref_out, ( + ref_a_grad, + ref_s_grad, + ref_k1_grad, + ref_k2_grad, + ref_b1_grad, + ref_b2_grad, + ref_amax_list_1, + ref_amax_list_2, + ref_scale_list_1, + ref_scale_list_2, + ) = value_n_grad_ref_func( + a, + s, + k1, + k2, + b1, + b2, + ref_amax_list_1, + ref_amax_list_2, + ref_scale_list_1, + ref_scale_list_2, + ) for _ in range(3): - primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad, - primitive_k2_grad, primitive_b1_grad, primitive_b2_grad, - primitive_amax_list_1, primitive_amax_list_2, primitive_scale_list_1, - primitive_scale_list_2) = value_n_grad_primitive_func( - a, s, k1, k2, b1, b2, primitive_amax_list_1, primitive_amax_list_2, - primitive_scale_list_1, primitive_scale_list_2) + primitive_out, ( + primitive_a_grad, + primitive_s_grad, + primitive_k1_grad, + primitive_k2_grad, + primitive_b1_grad, + primitive_b2_grad, + primitive_amax_list_1, + primitive_amax_list_2, + primitive_scale_list_1, + primitive_scale_list_2, + ) = value_n_grad_primitive_func( + a, + s, + k1, + k2, + b1, + b2, + primitive_amax_list_1, + primitive_amax_list_2, + primitive_scale_list_1, + primitive_scale_list_2, + ) assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) - assert_allclose(jnp.asarray(primitive_a_grad, np.float32), - jnp.asarray(ref_a_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_k1_grad, np.float32), - jnp.asarray(ref_k1_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_s_grad, np.float32), - jnp.asarray(ref_s_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_k2_grad, np.float32), - jnp.asarray(ref_k2_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) + assert_allclose( + jnp.asarray(primitive_a_grad, np.float32), + jnp.asarray(ref_a_grad, np.float32), + dtype=FP8Helper.BWD_DTYPE, + ) + assert_allclose( + jnp.asarray(primitive_k1_grad, np.float32), + jnp.asarray(ref_k1_grad, np.float32), + dtype=FP8Helper.BWD_DTYPE, + ) + assert_allclose( + jnp.asarray(primitive_s_grad, np.float32), + jnp.asarray(ref_s_grad, np.float32), + dtype=FP8Helper.BWD_DTYPE, + ) + assert_allclose( + 
jnp.asarray(primitive_k2_grad, np.float32), + jnp.asarray(ref_k2_grad, np.float32), + dtype=FP8Helper.BWD_DTYPE, + ) if use_bias: - assert_allclose(jnp.asarray(primitive_b2_grad, np.float32), - jnp.asarray(ref_b2_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_b1_grad, np.float32), - jnp.asarray(ref_b1_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) + assert_allclose( + jnp.asarray(primitive_b2_grad, np.float32), + jnp.asarray(ref_b2_grad, np.float32), + dtype=FP8Helper.BWD_DTYPE, + ) + assert_allclose( + jnp.asarray(primitive_b1_grad, np.float32), + jnp.asarray(ref_b1_grad, np.float32), + dtype=FP8Helper.BWD_DTYPE, + ) @pytest.fixture(name="random_inputs") @@ -402,17 +457,22 @@ class TestActivationLu: def primitive_func(self, inputs): return jnp.mean(activation_lu(inputs, activation_type=self.activation_type)) - @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)]) - @pytest.mark.parametrize('activation_type', [('gelu',), - ('gelu', 'linear'), - ('silu',), - ('silu', 'linear'), - ('relu',), - ('relu', 'linear'), - ('quick_gelu',), - ('quick_gelu', 'linear'), - ('squared_relu',), - ('squared_relu', 'linear') ]) + @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)]) + @pytest.mark.parametrize( + "activation_type", + [ + ("gelu",), + ("gelu", "linear"), + ("silu",), + ("silu", "linear"), + ("relu",), + ("relu", "linear"), + ("quick_gelu",), + ("quick_gelu", "linear"), + ("squared_relu",), + ("squared_relu", "linear"), + ], + ) def test_activation_lu(self, random_inputs, activation_type): x = random_inputs x = jnp.repeat(x, len(activation_type), axis=1) @@ -441,23 +501,34 @@ class TestActivationLuFP8(TestActivationLu): return output def _prim_func_fwd(x, _x_t, _dbias, _amax): - activation_lu_out, _ = tex.act_lu_fp8(x, amax, scale, scale_inv, - FP8Helper.FWD_DTYPE, activation_type) + activation_lu_out, _ = tex.act_lu_fp8( + x, amax, scale, scale_inv, FP8Helper.FWD_DTYPE, activation_type + ) activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv) - ctx = (x) + ctx = x return activation_lu_out, ctx def _prim_func_bwd(ctx, g): x = ctx - if len(self.activation_type) > 1: #gated, no bias - dactivation_lu, dactivation_lu_trans, amax_out = \ - tex.dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv, - FP8Helper.BWD_DTYPE, -1, activation_type) + if len(self.activation_type) > 1: # gated, no bias + dactivation_lu, dactivation_lu_trans, amax_out = tex.dgated_act_lu_cast_transpose( + g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, -1, activation_type + ) dbias = jnp.empty(x.shape[-1], x.dtype) - else: #not gated, with bias - dactivation_lu, dactivation_lu_trans, dbias, amax_out = \ - tex.dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, - -1, -2, self.activation_type) + else: # not gated, with bias + dactivation_lu, dactivation_lu_trans, dbias, amax_out = ( + tex.dact_lu_dbias_cast_transpose( + g, + x, + amax, + scale, + scale_inv, + FP8Helper.BWD_DTYPE, + -1, + -2, + self.activation_type, + ) + ) dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv) dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv) ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out) @@ -468,23 +539,28 @@ class TestActivationLuFP8(TestActivationLu): dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype) dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype) amax_no_use = jnp.zeros(1, jnp.float32) - value_n_grad_primitive_func = value_and_grad(lambda 
a, b, c, d: - jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3)) + value_n_grad_primitive_func = value_and_grad( + lambda a, b, c, d: jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3) + ) return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)]) - @pytest.mark.parametrize('activation_type', [('gelu',), - ('gelu', 'linear'), - ('silu',), - ('silu', 'linear'), - ('relu',), - ('relu', 'linear'), - ('quick_gelu',), - ('quick_gelu', 'linear'), - ('squared_relu',), - ('squared_relu', 'linear') ]) + @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)]) + @pytest.mark.parametrize( + "activation_type", + [ + ("gelu",), + ("gelu", "linear"), + ("silu",), + ("silu", "linear"), + ("relu",), + ("relu", "linear"), + ("quick_gelu",), + ("quick_gelu", "linear"), + ("squared_relu",), + ("squared_relu", "linear"), + ], + ) def test_activation_lu(self, random_inputs, activation_type): self.amax = jnp.zeros(1, jnp.float32) self.scale = jnp.ones(1, jnp.float32) @@ -500,12 +576,14 @@ class TestActivationLuFP8(TestActivationLu): assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE) assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2) - if 'linear' not in activation_type: + if "linear" not in activation_type: assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1)))) assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE) - assert_allclose(prim_grad_trans, - jnp.transpose(ref_grad, self.transpose_indices), - dtype=FP8Helper.BWD_DTYPE) + assert_allclose( + prim_grad_trans, + jnp.transpose(ref_grad, self.transpose_indices), + dtype=FP8Helper.BWD_DTYPE, + ) class TestNorm: @@ -536,34 +614,38 @@ class TestNorm: """ x_ = jnp.asarray(x, jnp.float32) if bias is None: - mean = 0. + mean = 0.0 else: mean = jnp.mean(x_, axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps) if zero_centered_gamma: - scale += 1. + scale += 1.0 if bias is None: - bias = 0. + bias = 0.0 return jnp.asarray(normed_input * scale + bias).astype(x.dtype) - @pytest.mark.parametrize('n, hidden', LN_CASES) - @pytest.mark.parametrize('dtype', DTYPES) - @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm']) - @pytest.mark.parametrize('zero_centered_gamma', [False, True]) - @pytest.mark.parametrize('epsilon', [1e-2, 1e-6]) - def test_layernorm_forward_backward(self, n, hidden, ln_type, zero_centered_gamma, epsilon, - dtype): + @pytest.mark.parametrize("n, hidden", LN_CASES) + @pytest.mark.parametrize("dtype", DTYPES) + @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"]) + @pytest.mark.parametrize("zero_centered_gamma", [False, True]) + @pytest.mark.parametrize("epsilon", [1e-2, 1e-6]) + def test_layernorm_forward_backward( + self, n, hidden, ln_type, zero_centered_gamma, epsilon, dtype + ): """ Test transformer_engine.jax.layernorm.layernorm """ expect_assert = False - if ln_type == 'rmsnorm' and zero_centered_gamma: + if ln_type == "rmsnorm" and zero_centered_gamma: # zero_centered_gamma is not supported for rmsnorm, expect an assertion. 
expect_assert = True - with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*" - ) if expect_assert else nullcontext(): + with ( + pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*") + if expect_assert + else nullcontext() + ): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 3) @@ -571,7 +653,7 @@ class TestNorm: gamma_range = (-1, 1) if zero_centered_gamma else (0, 2) gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range) gamma = jnp.asarray(gamma, dtype) - if ln_type == 'layernorm': + if ln_type == "layernorm": beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1) beta = jnp.asarray(beta, dtype) else: @@ -585,19 +667,27 @@ class TestNorm: jitted_primitive = jit( value_and_grad( lambda x, gamma, beta: compute_loss( - layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)), - (0, 1, 2))) + layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon) + ), + (0, 1, 2), + ) + ) jitted_reference = jit( value_and_grad( lambda x, gamma, beta: compute_loss( - self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)), - (0, 1, 2))) + self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon) + ), + (0, 1, 2), + ) + ) - primitive_out, (primitive_dx, primitive_dgamma, - primitive_dbeta) = jitted_primitive(x, gamma, beta) - reference_out, (reference_dx, reference_dgamma, - reference_dbeta) = jitted_reference(x, gamma, beta) + primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive( + x, gamma, beta + ) + reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference( + x, gamma, beta + ) assert_allclose(primitive_out, reference_out, dtype=dtype) assert_allclose(primitive_dx, reference_dx, dtype=dtype) @@ -606,21 +696,24 @@ class TestNorm: assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('m,n,k', GEMM_CASES) - @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm']) - @pytest.mark.parametrize('zero_centered_gamma', [True, False]) - @pytest.mark.parametrize('epsilon', [1e-2, 1e-6]) + @pytest.mark.parametrize("m,n,k", GEMM_CASES) + @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"]) + @pytest.mark.parametrize("zero_centered_gamma", [True, False]) + @pytest.mark.parametrize("epsilon", [1e-2, 1e-6]) def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon): """ Test transformer_engine.jax.layernorm.layernorm_fp8_dot """ expect_assert = False - if ln_type == 'rmsnorm' and zero_centered_gamma: + if ln_type == "rmsnorm" and zero_centered_gamma: # zero_centered_gamma is not supported for rmsnorm, expect an assertion. 
expect_assert = True - with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*" - ) if expect_assert else nullcontext(): + with ( + pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*") + if expect_assert + else nullcontext() + ): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 4) @@ -628,7 +721,7 @@ class TestNorm: b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16) - if ln_type == 'layernorm': + if ln_type == "layernorm": beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16) else: beta = None @@ -644,8 +737,9 @@ class TestNorm: amax_list_1[2], scale_list_1[2], ) - primitive_out = layernorm_fp8_dot(x, y, gamma, beta, fp8_meta_pkg, ln_type, - zero_centered_gamma) + primitive_out = layernorm_fp8_dot( + x, y, gamma, beta, fp8_meta_pkg, ln_type, zero_centered_gamma + ) return jnp.mean(primitive_out) def ref_func(x, y, gamma, beta, zero_centered_gamma): @@ -655,14 +749,19 @@ class TestNorm: value_n_grad_primitive_func = value_and_grad(primitive_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3)) - ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, - ref_beta_grad) = value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma) + ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, ref_beta_grad) = ( + value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma) + ) for _ in range(3): - primitive_out, (primitive_a_grad, primitive_b_grad, primitive_gamma_grad, - primitive_beta_grad, amax_list_1, - scale_list_1) = value_n_grad_primitive_func( - a, b, gamma, beta, amax_list_1, scale_list_1) + primitive_out, ( + primitive_a_grad, + primitive_b_grad, + primitive_gamma_grad, + primitive_beta_grad, + amax_list_1, + scale_list_1, + ) = value_n_grad_primitive_func(a, b, gamma, beta, amax_list_1, scale_list_1) assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 3efee867982c19117c09fb56ea306ddd8841a0c6..8a6c5792ff517f77800f8ef35e5daa760c7cb599 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -9,16 +9,8 @@ import jax.numpy as jnp import numpy as np from flax.linen import dot_product_attention from jax import random -from jax.sharding import ( - Mesh, - NamedSharding, - PartitionSpec -) -from distributed_test_base import ( - generate_configs, - generate_collectives_count, - compare_ops -) +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from distributed_test_base import generate_configs, generate_collectives_count, compare_ops from utils import make_causal_mask, make_self_mask from transformer_engine.jax import fp8_autocast from transformer_engine.jax.attention import ( @@ -27,7 +19,7 @@ from transformer_engine.jax.attention import ( fused_attn_kvpacked, AttnBiasType, AttnMaskType, - QKVLayout + QKVLayout, ) @@ -36,8 +28,9 @@ DTYPES = [jnp.float16, jnp.bfloat16] class TestDistributedSelfAttn: - def generate_collectives_count_ref(self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, - dtype): + def generate_collectives_count_ref( + self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype + ): jax_dtype = jax.dtypes.canonicalize_dtype(dtype) _, seqlen, _, heads, _ = shape is_dp_enabled = mesh_resource.dp_resource is not None @@ -46,7 +39,7 @@ class 
TestDistributedSelfAttn: idx = mesh_axes.index(mesh_resource.tp_resource) tp_size = mesh_shape[idx] - all_reduce_loss_bytes = 4 # 1 * FP32 + all_reduce_loss_bytes = 4 # 1 * FP32 bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled) # for loss and dbias @@ -57,8 +50,11 @@ class TestDistributedSelfAttn: qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype) - bias = random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) \ - if with_bias else None + bias = ( + random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) + if with_bias + else None + ) mask = None if attn_mask_type == AttnMaskType.PADDING_MASK: @@ -66,47 +62,76 @@ class TestDistributedSelfAttn: elif attn_mask_type == AttnMaskType.CAUSAL_MASK: mask = make_self_mask(batch, seqlen) - qkv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, - None) - bias_pspec = PartitionSpec(None, mesh_resource.tp_resource, None, None) \ - if with_bias else None - mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \ - if attn_mask_type != AttnMaskType.NO_MASK else None + qkv_pspec = PartitionSpec( + mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None + ) + bias_pspec = ( + PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None + ) + mask_pspec = ( + PartitionSpec(mesh_resource.dp_resource, None, None, None) + if attn_mask_type != AttnMaskType.NO_MASK + else None + ) return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) - @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) - @pytest.mark.parametrize('data_shape', [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]]) + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]]) @pytest.mark.parametrize( - 'attn_bias_type', - [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS]) - @pytest.mark.parametrize('attn_mask_type', - [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) - @pytest.mark.parametrize('dtype', DTYPES) - def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, - attn_bias_type, attn_mask_type, dtype): + "attn_bias_type", + [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS], + ) + @pytest.mark.parametrize( + "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK] + ) + @pytest.mark.parametrize("dtype", DTYPES) + def test_self_attn( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + attn_bias_type, + attn_mask_type, + dtype, + ): dropout_prob = 0.0 is_training = True scaling_factor = 1.0 _, seqlen, _, num_head, hidden = data_shape - if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type, - attn_mask_type, dropout_prob, num_head, num_head, - seqlen, seqlen, hidden): + if not is_fused_attn_kernel_available( + dtype, + dtype, + QKVLayout.BS3HD, + attn_bias_type, + attn_mask_type, + dropout_prob, + num_head, + num_head, + seqlen, + seqlen, + hidden, + ): pytest.skip(f"No FusedAttn backwend found") def target_func(qkv, bias, mask): return jnp.mean( - fused_attn_qkvpacked(qkv, - bias, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - 
dropout_probability=dropout_prob, - is_training=is_training)) + fused_attn_qkvpacked( + qkv, + bias, + mask, + None, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training, + ) + ) def ref_func(qkv, bias, mask): query, key, value = jnp.split(qkv, [1, 2], axis=-3) @@ -114,52 +139,59 @@ class TestDistributedSelfAttn: key = jnp.squeeze(key) value = jnp.squeeze(value) - output = dot_product_attention(query, - key, - value, - bias=bias, - mask=mask, - deterministic=is_training, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32) + output = dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + deterministic=is_training, + dropout_rate=dropout_prob, + dropout_rng=None, + dtype=jnp.float32, + ) return jnp.mean(output).astype(dtype) with_bias = attn_bias_type != AttnBiasType.NO_BIAS - (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = \ - self.generate_inputs(data_shape, mesh_resource, with_bias, - attn_mask_type, dtype) - collective_count_ref = self.generate_collectives_count_ref(mesh_shape, mesh_axes, - mesh_resource, with_bias, - data_shape, dtype) + (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs( + data_shape, mesh_resource, with_bias, attn_mask_type, dtype + ) + collective_count_ref = self.generate_collectives_count_ref( + mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype + ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast(mesh_resource=mesh_resource): qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec)) - bias_ = jax.device_put(bias, NamedSharding(mesh, bias_pspec)) \ - if bias is not None else bias - mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \ - if mask is not None else mask + bias_ = ( + jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias + ) + mask_ = ( + jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask + ) grad_args = (0, 1) if with_bias else (0,) out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,) - compare_ops(target_func, - ref_func, [qkv_, bias_, mask_], - collective_count_ref, - grad_args=grad_args, - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(qkv_pspec, bias_pspec, mask_pspec), - out_shardings=(None, out_grad_shardings)) + compare_ops( + target_func, + ref_func, + [qkv_, bias_, mask_], + collective_count_ref, + grad_args=grad_args, + metric_fwd_dtype=dtype, + metric_bwd_dtype=dtype, + in_shardings=(qkv_pspec, bias_pspec, mask_pspec), + out_shardings=(None, out_grad_shardings), + ) class TestDistributedCrossAttn: def generate_collectives_count_ref(self): # for loss - all_reduce_loss_bytes = 4 # 1 * FP32 + all_reduce_loss_bytes = 4 # 1 * FP32 return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype): @@ -176,20 +208,26 @@ class TestDistributedCrossAttn: q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None) - kv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, - None) - mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \ - if attn_mask_type != AttnMaskType.NO_MASK else None + kv_pspec = PartitionSpec( + mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None + ) + 
mask_pspec = ( + PartitionSpec(mesh_resource.dp_resource, None, None, None) + if attn_mask_type != AttnMaskType.NO_MASK + else None + ) return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) - @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) - @pytest.mark.parametrize('data_shape', [[32, 128, 12, 64], [32, 512, 16, 64]]) - @pytest.mark.parametrize('attn_mask_type', - [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) - @pytest.mark.parametrize('dtype', DTYPES) - def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, - attn_mask_type, dtype): + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]]) + @pytest.mark.parametrize( + "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK] + ) + @pytest.mark.parametrize("dtype", DTYPES) + def test_cross_attn( + self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype + ): attn_bias_type = AttnBiasType.NO_BIAS dropout_prob = 0.0 is_training = True @@ -197,23 +235,36 @@ class TestDistributedCrossAttn: _, seqlen, num_head, hidden = data_shape - if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type, - attn_mask_type, dropout_prob, num_head, num_head, - seqlen, seqlen, hidden): + if not is_fused_attn_kernel_available( + dtype, + dtype, + QKVLayout.BSHD_BS2HD, + attn_bias_type, + attn_mask_type, + dropout_prob, + num_head, + num_head, + seqlen, + seqlen, + hidden, + ): pytest.skip(f"No FusedAttn backwend found") def target_func(q, kv, mask): return jnp.mean( - fused_attn_kvpacked(q, - kv, - None, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training)) + fused_attn_kvpacked( + q, + kv, + None, + mask, + None, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training, + ) + ) def ref_func(query, kv, mask): key, value = jnp.split(kv, [1], axis=-3) @@ -221,34 +272,41 @@ class TestDistributedCrossAttn: key = jnp.squeeze(key) value = jnp.squeeze(value) - output = dot_product_attention(query, - key, - value, - bias=None, - mask=mask, - deterministic=is_training, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32) + output = dot_product_attention( + query, + key, + value, + bias=None, + mask=mask, + deterministic=is_training, + dropout_rate=dropout_prob, + dropout_rng=None, + dtype=jnp.float32, + ) return jnp.mean(output).astype(dtype) - (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = \ - self.generate_inputs(data_shape, mesh_resource, attn_mask_type, dtype) + (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs( + data_shape, mesh_resource, attn_mask_type, dtype + ) collective_count_ref = self.generate_collectives_count_ref() devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast(mesh_resource=mesh_resource): q_ = jax.device_put(q, NamedSharding(mesh, q_pspec)) kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec)) - mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \ - if mask is not None else mask - - compare_ops(target_func, - ref_func, [q_, kv_, mask_], - collective_count_ref, - grad_args=(0, 1), - metric_fwd_dtype=dtype, - 
metric_bwd_dtype=dtype, - in_shardings=(q_pspec, kv_pspec, mask_pspec), - out_shardings=(None, (q_pspec, kv_pspec))) + mask_ = ( + jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask + ) + + compare_ops( + target_func, + ref_func, + [q_, kv_, mask_], + collective_count_ref, + grad_args=(0, 1), + metric_fwd_dtype=dtype, + metric_bwd_dtype=dtype, + in_shardings=(q_pspec, kv_pspec, mask_pspec), + out_shardings=(None, (q_pspec, kv_pspec)), + ) diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 3aa5b9ae203a9d9e37d313706bbd92896aa86834..f0dd56feaaa737f7498a6af5ac9cec3332a05177 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -35,31 +35,44 @@ class TestDistributedLayernorm: else: raise NotImplementedError - g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None) + g_pspec = b_pspec = ( + PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None) + ) return (x, gamma, beta), (x_pspec, g_pspec, b_pspec) def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype): jax_dtype = jax.dtypes.canonicalize_dtype(dtype) is_dp_enabled = mesh_resource.dp_resource is not None - assert ln_type in ['layernorm', 'rmsnorm'] - all_reduce_loss_bytes = 4 # 1 * FP32 + assert ln_type in ["layernorm", "rmsnorm"] + all_reduce_loss_bytes = 4 # 1 * FP32 # for loss, dgamma and dbeta - weight_count = 2 if ln_type == 'layernorm' else 1 - allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize - return generate_collectives_count(allreduce=allreduce_total_bytes * int(is_dp_enabled), - allgather=0, - other=0) - - @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) - @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]]) - @pytest.mark.parametrize('dtype', DTYPES) - @pytest.mark.parametrize('zero_centered_gamma', [False, True]) - @pytest.mark.parametrize('shard_weights', [False, True]) - def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, - zero_centered_gamma, shard_weights): + weight_count = 2 if ln_type == "layernorm" else 1 + allreduce_total_bytes = ( + all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize + ) + return generate_collectives_count( + allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0 + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]]) + @pytest.mark.parametrize("dtype", DTYPES) + @pytest.mark.parametrize("zero_centered_gamma", [False, True]) + @pytest.mark.parametrize("shard_weights", [False, True]) + def test_layernorm( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + dtype, + zero_centered_gamma, + shard_weights, + ): epsilon = 1e-6 - ln_type = 'layernorm' + ln_type = "layernorm" def target_func(x, gamma, beta): return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)) @@ -75,10 +88,12 @@ class TestDistributedLayernorm: output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype) return jnp.mean(output) - (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \ - self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights) - collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, - data_shape, 
dtype) + (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs( + data_shape, mesh_resource, dtype, shard_weights + ) + collective_count_ref = self.generate_collectives_count_ref( + mesh_resource, ln_type, data_shape, dtype + ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast(mesh_resource=mesh_resource): @@ -88,20 +103,25 @@ class TestDistributedLayernorm: with warnings.catch_warnings(record=True) as warns: try: - compare_ops(target_func, - ref_func, [x_, gamma_, beta_], - collective_count_ref, - grad_args=(0, 1, 2), - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(x_pspec, g_pspec, b_pspec), - out_shardings=(None, (x_pspec, g_pspec, b_pspec))) + compare_ops( + target_func, + ref_func, + [x_, gamma_, beta_], + collective_count_ref, + grad_args=(0, 1, 2), + metric_fwd_dtype=dtype, + metric_bwd_dtype=dtype, + in_shardings=(x_pspec, g_pspec, b_pspec), + out_shardings=(None, (x_pspec, g_pspec, b_pspec)), + ) except AssertionError as err: # Layernorm should still produce the correct numerical result with # gamma/beta sharded. However, the collective count may not be the same # when XLA is forced to unshard gamma and/or beta. We can catch # and ignore that specific error here. - if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err): + if ( + g_pspec[-1] is None and b_pspec[-1] is None + ) or "Expected collective count" not in str(err): raise err finally: for w in warns: @@ -110,13 +130,15 @@ class TestDistributedLayernorm: "unsupported sharding of gamma and/or beta" ) - @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) - @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]]) - @pytest.mark.parametrize('dtype', DTYPES) - @pytest.mark.parametrize('shard_weights', [False, True]) - def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights): + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]]) + @pytest.mark.parametrize("dtype", DTYPES) + @pytest.mark.parametrize("shard_weights", [False, True]) + def test_rmsnorm( + self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights + ): epsilon = 1e-6 - ln_type = 'rmsnorm' + ln_type = "rmsnorm" def target_func(x, gamma): return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon)) @@ -128,10 +150,12 @@ class TestDistributedLayernorm: output = y * gamma return jnp.mean(output) - (x, gamma, _), (x_pspec, g_pspec, _) = \ - self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights) - collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, - data_shape, dtype) + (x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs( + data_shape, mesh_resource, dtype, shard_weights + ) + collective_count_ref = self.generate_collectives_count_ref( + mesh_resource, ln_type, data_shape, dtype + ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast(mesh_resource=mesh_resource): @@ -140,14 +164,17 @@ class TestDistributedLayernorm: with warnings.catch_warnings(record=True) as warns: try: - compare_ops(target_func, - ref_func, [x_, gamma_], - collective_count_ref, - grad_args=(0, 1), - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - 
in_shardings=(x_pspec, g_pspec), - out_shardings=(None, (x_pspec, g_pspec))) + compare_ops( + target_func, + ref_func, + [x_, gamma_], + collective_count_ref, + grad_args=(0, 1), + metric_fwd_dtype=dtype, + metric_bwd_dtype=dtype, + in_shardings=(x_pspec, g_pspec), + out_shardings=(None, (x_pspec, g_pspec)), + ) except AssertionError as err: # RmsNorm should still produce the correct numerical result with # gamma/beta sharded. However, the collective count may not be the same diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index e61b013b8fd7aa3178953b83ad72d18ddb8ac906..38f7ec0d49f9408f08971cde0a57db460a329a7d 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -15,22 +15,23 @@ from transformer_engine.jax import fp8_autocast from transformer_engine.jax.flax import LayerNormMLP from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp from transformer_engine.jax.sharding import ( - HIDDEN_AXES, HIDDEN_TP_AXES, + HIDDEN_AXES, + HIDDEN_TP_AXES, BATCH_AXES, - SEQLEN_TP_AXES, SEQLEN_AXES, - W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES + SEQLEN_TP_AXES, + SEQLEN_AXES, + W_NO_SHARD_AXES, + W_FSDP_AXES, + W_TP_AXES, + W_JOINED_AXES, ) from transformer_engine.jax.sharding import MeshResource -from utils import ( - assert_allclose, - assert_tree_like_allclose, - is_devices_enough -) +from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough is_fp8_supported, reason = is_fp8_available() DTYPES = [jnp.bfloat16, jnp.float16] -INPUT_SHAPE = [[64, 128, 32]] # [batch, seqlen, hidden_in] +INPUT_SHAPE = [[64, 128, 32]] # [batch, seqlen, hidden_in] LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES) DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES) @@ -43,13 +44,13 @@ def generate_fsdp_and_tp_configs(): configs = [] if is_devices_enough(2): configs.append( - [2, (1, 2), ('fsdp', 'tp'), - MeshResource(fsdp_resource='fsdp', tp_resource='tp')]) + [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] + ) if is_devices_enough(4): configs.append( - [4, (2, 2), ('fsdp', 'tp'), - MeshResource(fsdp_resource='fsdp', tp_resource='tp')]) + [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] + ) return configs @@ -64,10 +65,12 @@ class TestDistributedLayernormMLP: x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype) - k1 = jax.random.normal(subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), - dtype) / jnp.sqrt(hidden_in) - k2 = jax.random.normal(subkeys[2], - (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(INTERMEDIATE) + k1 = jax.random.normal( + subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype + ) / jnp.sqrt(hidden_in) + k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt( + INTERMEDIATE + ) if use_bias: b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype) b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype) @@ -90,15 +93,27 @@ class TestDistributedLayernormMLP: scale_list_1: List[jnp.ndarray], scale_list_2: List[jnp.ndarray], layernorm_type: str = "rmsnorm", - activation_type: Sequence[Union[str, Callable]] = ('gelu',), + activation_type: Sequence[Union[str, Callable]] = ("gelu",), use_bias: bool = True, multi_gpus: bool = False, ) -> jnp.ndarray: - fp8_meta_pkg1 = FP8MetaPackage(amax_list_1[0], scale_list_1[0], 
amax_list_1[1], - scale_list_1[1], amax_list_1[2], scale_list_1[2]) - fp8_meta_pkg2 = FP8MetaPackage(amax_list_2[0], scale_list_2[0], amax_list_2[1], - scale_list_2[1], amax_list_2[2], scale_list_2[2]) + fp8_meta_pkg1 = FP8MetaPackage( + amax_list_1[0], + scale_list_1[0], + amax_list_1[1], + scale_list_1[1], + amax_list_1[2], + scale_list_1[2], + ) + fp8_meta_pkg2 = FP8MetaPackage( + amax_list_2[0], + scale_list_2[0], + amax_list_2[1], + scale_list_2[1], + amax_list_2[2], + scale_list_2[2], + ) if multi_gpus: layernorm_input_axes = LAYERNORM_INPUT_AXES @@ -111,60 +126,68 @@ class TestDistributedLayernormMLP: # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2 return jnp.mean( - fused_layernorm_fp8_mlp(x, - ln_scale, - None, [kernel_1, kernel_2], [bias_1, bias_2], - [fp8_meta_pkg1, fp8_meta_pkg2], - layernorm_type, - layernorm_input_axes=layernorm_input_axes, - dot_1_input_axes=dot_1_input_axes, - dot_2_input_axes=dot_2_input_axes, - activation_type=activation_type, - use_bias=use_bias)) + fused_layernorm_fp8_mlp( + x, + ln_scale, + None, + [kernel_1, kernel_2], + [bias_1, bias_2], + [fp8_meta_pkg1, fp8_meta_pkg2], + layernorm_type, + layernorm_input_axes=layernorm_input_axes, + dot_1_input_axes=dot_1_input_axes, + dot_2_input_axes=dot_2_input_axes, + activation_type=activation_type, + use_bias=use_bias, + ) + ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs()) - @pytest.mark.parametrize('input_shape', INPUT_SHAPE) - @pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear')]) - @pytest.mark.parametrize('dtype', DTYPES) - @pytest.mark.parametrize('use_bias', [True, False]) - def test_layernorm_fp8_mlp_primitive(self, mesh_config, activation_type, use_bias, input_shape, - dtype): + @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs()) + @pytest.mark.parametrize("input_shape", INPUT_SHAPE) + @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) + @pytest.mark.parametrize("dtype", DTYPES) + @pytest.mark.parametrize("use_bias", [True, False]) + def test_layernorm_fp8_mlp_primitive( + self, mesh_config, activation_type, use_bias, input_shape, dtype + ): device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config - layernorm_type = 'rmsnorm' + layernorm_type = "rmsnorm" fp8_amax_list_1 = [ jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), - jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32) + jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), ] fp8_amax_list_2 = [ jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), - jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32) + jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), ] fp8_scale_list_1 = [ jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32), - jnp.ones((1,), jnp.float32) + jnp.ones((1,), jnp.float32), ] fp8_scale_list_2 = [ jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32), - jnp.ones((1,), jnp.float32) + jnp.ones((1,), jnp.float32), ] - inputs = [x, gamma, k1, k2, b1, b2] = \ - self.generate_inputs(input_shape, activation_type, use_bias, dtype) + inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs( + input_shape, activation_type, use_bias, dtype + ) inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2] static_inputs = [layernorm_type, activation_type, use_bias] - value_and_grad_func = 
jax.value_and_grad(self.layernorm_fp8_mlp_prim_func, - argnums=range(len(inputs))) + value_and_grad_func = jax.value_and_grad( + self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs)) + ) # Single GPU - single_jitter = jax.jit(value_and_grad_func, - static_argnums=range(len(inputs), - len(static_inputs) + len(inputs))) + single_jitter = jax.jit( + value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs)) + ) with fp8_autocast(enabled=True): single_fwd, single_grads = single_jitter(*inputs, *static_inputs) @@ -172,12 +195,12 @@ class TestDistributedLayernormMLP: devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource): - k1_sharding = NamedSharding(mesh, PartitionSpec('fsdp', None, 'tp')) - k2_sharding = NamedSharding(mesh, PartitionSpec('tp', 'fsdp')) + k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp")) + k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) k1_ = jax.device_put(k1, k1_sharding) k2_ = jax.device_put(k2, k2_sharding) if use_bias: - b1_sharding = NamedSharding(mesh, PartitionSpec(None, 'tp')) + b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp")) b1_ = jax.device_put(b1, b1_sharding) else: b1_sharding = b1_ = None @@ -186,17 +209,29 @@ class TestDistributedLayernormMLP: # Position ref for sharding pspec lists # x, gamma, k1, k2, b1, # b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv - in_shardings = (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, - None, None) - out_shardings = (None, (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, - None, None, None)) - - multi_jitter = jax.jit(value_and_grad_func, - in_shardings=in_shardings, - out_shardings=out_shardings, - static_argnums=range(len(multi_inputs), - len(static_inputs) + len(multi_inputs) + - 1)) # +1 for multi_gpus + in_shardings = ( + None, + None, + k1_sharding, + k2_sharding, + b1_sharding, + None, + None, + None, + None, + None, + ) + out_shardings = ( + None, + (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, None, None), + ) + + multi_jitter = jax.jit( + value_and_grad_func, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1), + ) # +1 for multi_gpus multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) @@ -206,97 +241,96 @@ class TestDistributedLayernormMLP: if isinstance(multi_grads[i], list): assert isinstance(single_grads[i], list) for m_grad, s_grad in zip(multi_grads[i], single_grads[i]): - assert_allclose(m_grad, - s_grad, - dtype=dtype, - err_msg=f'multi_grads[{i}] is not close') + assert_allclose( + m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close" + ) else: - assert_allclose(multi_grads[i], - single_grads[i], - dtype=dtype, - err_msg=f'multi_grads[{i}] is not close') - - def _test_layernorm_mlp(self, mesh_config, activation_type, use_bias, input_shape, dtype, - use_fp8): + assert_allclose( + multi_grads[i], + single_grads[i], + dtype=dtype, + err_msg=f"multi_grads[{i}] is not close", + ) + + def _test_layernorm_mlp( + self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8 + ): batch, seqlen, hidden_in = input_shape - layernorm_type = 'rmsnorm' + layernorm_type = "rmsnorm" rng = jax.random.PRNGKey(0) subkeys = jax.random.split(rng, 2) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), 
dtype) - init_rngs = {'params': subkeys[1]} + init_rngs = {"params": subkeys[1]} # Single GPUs with fp8_autocast(enabled=use_fp8): ln_mlp_single = LayerNormMLP( layernorm_type=layernorm_type, - transpose_batch_sequence=False, # input: [batch, seqlen, hidden] + transpose_batch_sequence=False, # input: [batch, seqlen, hidden] intermediate_dim=INTERMEDIATE, activations=activation_type, dtype=dtype, use_bias=use_bias, ) params_single = ln_mlp_single.init(init_rngs, x) - mlp_out_single, ln_out_single = ln_mlp_single.apply(params_single, - x, - deterministic=True) + mlp_out_single, ln_out_single = ln_mlp_single.apply( + params_single, x, deterministic=True + ) # Multi GPUs device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource): - ln_mlp_sharded = LayerNormMLP(layernorm_type=layernorm_type, - transpose_batch_sequence=False, - intermediate_dim=INTERMEDIATE, - activations=activation_type, - dtype=dtype, - scale_axes=(W_NO_SHARD_AXES,), - ln_bias_axes=(W_NO_SHARD_AXES,), - kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), - kernel_axes_2=(W_TP_AXES, W_FSDP_AXES), - use_bias=use_bias, - bias_axes_1=(W_JOINED_AXES, W_TP_AXES), - bias_axes_2=(W_NO_SHARD_AXES,), - layernorm_input_axes=LAYERNORM_INPUT_AXES, - dot_1_input_axes=DOT_1_INPUT_AXES, - dot_2_input_axes=DOT_2_INPUT_AXES, - name='mlp') + ln_mlp_sharded = LayerNormMLP( + layernorm_type=layernorm_type, + transpose_batch_sequence=False, + intermediate_dim=INTERMEDIATE, + activations=activation_type, + dtype=dtype, + scale_axes=(W_NO_SHARD_AXES,), + ln_bias_axes=(W_NO_SHARD_AXES,), + kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), + kernel_axes_2=(W_TP_AXES, W_FSDP_AXES), + use_bias=use_bias, + bias_axes_1=(W_JOINED_AXES, W_TP_AXES), + bias_axes_2=(W_NO_SHARD_AXES,), + layernorm_input_axes=LAYERNORM_INPUT_AXES, + dot_1_input_axes=DOT_1_INPUT_AXES, + dot_2_input_axes=DOT_2_INPUT_AXES, + name="mlp", + ) params_sharded = ln_mlp_sharded.init(init_rngs, x) - mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(params_sharded, - x, - deterministic=True) + mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( + params_sharded, x, deterministic=True + ) # Make sure params values are the same - assert_tree_like_allclose(params_sharded['params'], params_single['params']) + assert_tree_like_allclose(params_sharded["params"], params_single["params"]) assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype) assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype) - @pytest.mark.parametrize('input_shape', INPUT_SHAPE) - @pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs()) - @pytest.mark.parametrize('activation_type', [("gelu",), ('silu', 'linear'), ('gelu', 'gelu')]) - @pytest.mark.parametrize('dtype', DTYPES) - @pytest.mark.parametrize('use_bias', [True, False]) + @pytest.mark.parametrize("input_shape", INPUT_SHAPE) + @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs()) + @pytest.mark.parametrize("activation_type", [("gelu",), ("silu", "linear"), ("gelu", "gelu")]) + @pytest.mark.parametrize("dtype", DTYPES) + @pytest.mark.parametrize("use_bias", [True, False]) def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype): - self._test_layernorm_mlp(mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - use_fp8=False) + self._test_layernorm_mlp( + mesh_config, 
activation_type, use_bias, input_shape, dtype, use_fp8=False + ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs()) - @pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear'), ('gelu', 'gelu')]) - @pytest.mark.parametrize('use_bias', [True, False]) - @pytest.mark.parametrize('input_shape', INPUT_SHAPE) - @pytest.mark.parametrize('dtype', DTYPES) - def test_layernorm_fp8_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, - dtype): - self._test_layernorm_mlp(mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - use_fp8=True) + @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs()) + @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear"), ("gelu", "gelu")]) + @pytest.mark.parametrize("use_bias", [True, False]) + @pytest.mark.parametrize("input_shape", INPUT_SHAPE) + @pytest.mark.parametrize("dtype", DTYPES) + def test_layernorm_fp8_mlp_layer( + self, mesh_config, activation_type, use_bias, input_shape, dtype + ): + self._test_layernorm_mlp( + mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=True + ) diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index daa3a9111dd9cd77f88971873a4f4bc341cb7d4c..0ed6b84fd525e8c9c1518410179ea3ed97645976 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -25,7 +25,7 @@ class TestDistributedSoftmax: def generate_collectives_count_ref(self): # for loss - all_reduce_loss_bytes = 4 # 1 * FP32 + all_reduce_loss_bytes = 4 # 1 * FP32 return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding): @@ -38,49 +38,65 @@ class TestDistributedSoftmax: mask = make_self_mask(batch, sqelen) if not bad_sharding: - x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource, - None, None) + x_pspec = PartitionSpec( + mesh_resource.dp_resource, mesh_resource.tp_resource, None, None + ) else: - x_pspec = PartitionSpec(mesh_resource.dp_resource, None, - None, mesh_resource.tp_resource) + x_pspec = PartitionSpec( + mesh_resource.dp_resource, None, None, mesh_resource.tp_resource + ) mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) return (x, mask), (x_pspec, mask_pspec) @staticmethod def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED): - return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type)) + return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type)) @staticmethod def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16): bias = None if mask is not None: - bias = jax.lax.select(mask > 0, - jnp.full(mask.shape, -1e10).astype(dtype), - jnp.full(mask.shape, 0.).astype(dtype)) + bias = jax.lax.select( + mask > 0, + jnp.full(mask.shape, -1e10).astype(dtype), + jnp.full(mask.shape, 0.0).astype(dtype), + ) if bias is not None: x = x + bias.astype(dtype) output = jax.nn.softmax(x * scale_factor) return jnp.mean(output) - @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) - @pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]]) + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]]) 
@pytest.mark.parametrize( - 'softmax_type', - [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED]) - @pytest.mark.parametrize('scale_factor', [1.0, 3.0]) - @pytest.mark.parametrize('dtype', DTYPES) - @pytest.mark.parametrize('bad_sharding', [False, True]) - def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, - softmax_type, scale_factor, dtype, bad_sharding): - - target_func = partial(self.target_func, - scale_factor=scale_factor, - softmax_type=softmax_type) + "softmax_type", + [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED], + ) + @pytest.mark.parametrize("scale_factor", [1.0, 3.0]) + @pytest.mark.parametrize("dtype", DTYPES) + @pytest.mark.parametrize("bad_sharding", [False, True]) + def test_softmax( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + softmax_type, + scale_factor, + dtype, + bad_sharding, + ): + + target_func = partial( + self.target_func, scale_factor=scale_factor, softmax_type=softmax_type + ) ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype) - (x, mask), (x_pspec, mask_pspec) = \ - self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype, bad_sharding) + (x, mask), (x_pspec, mask_pspec) = self.generate_inputs( + data_shape, mesh_resource, softmax_type, dtype, bad_sharding + ) collective_count_ref = self.generate_collectives_count_ref() devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) @@ -90,14 +106,17 @@ class TestDistributedSoftmax: with warnings.catch_warnings(record=True) as warns: try: - compare_ops(target_func, - ref_func, [x_, mask_], - collective_count_ref, - grad_args=(0,), - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(x_pspec, mask_pspec), - out_shardings=(None, (x_pspec,))) + compare_ops( + target_func, + ref_func, + [x_, mask_], + collective_count_ref, + grad_args=(0,), + metric_fwd_dtype=dtype, + metric_bwd_dtype=dtype, + in_shardings=(x_pspec, mask_pspec), + out_shardings=(None, (x_pspec,)), + ) except AssertionError as err: # Softmax should still produce the correct numerical result with # bad sharding. 
However, the collective count may not be the same diff --git a/tests/jax/test_functions.py b/tests/jax/test_functions.py index aaa6be77acf589b035c77f0b8c7a77939c9f7c46..d6da307fd3d9400d05965a890c89018b72579828 100644 --- a/tests/jax/test_functions.py +++ b/tests/jax/test_functions.py @@ -20,12 +20,14 @@ class TestLoRA: out = jnp.einsum(pattern, x, la, lb) return out * scale - @pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)]) - @pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16]) - @pytest.mark.parametrize('axis_features_pattern', [((-1,), (1024,), '...h,hr,rk->...k'), - ((-1,), (3, 1024), '...h,hkr,krz->...kz')]) - @pytest.mark.parametrize('rank', [32, 16]) - @pytest.mark.parametrize('alpha', [None, 4, 8]) + @pytest.mark.parametrize("shape", [(32, 1024), (32, 128, 1024)]) + @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) + @pytest.mark.parametrize( + "axis_features_pattern", + [((-1,), (1024,), "...h,hr,rk->...k"), ((-1,), (3, 1024), "...h,hkr,krz->...kz")], + ) + @pytest.mark.parametrize("rank", [32, 16]) + @pytest.mark.parametrize("alpha", [None, 4, 8]) def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha): axis, features, pattern = axis_features_pattern axis = _normalize_axes(axis, len(shape)) @@ -49,16 +51,20 @@ class TestLoRA: assert_allclose(out_target, out_ref, dtype=dtype) - @pytest.mark.parametrize('scope_ref_assert', - [('none', LoRAScope(False, False, False), False), - ('all', LoRAScope(True, True, True), False), - ('qkv_proj', LoRAScope(True, False, False), False), - ('output_proj', LoRAScope(False, True, False), False), - ('mlp', LoRAScope(False, False, True), False), - ('exclude_qkv_proj', LoRAScope(False, True, True), False), - ('exclude_output_proj', LoRAScope(True, False, True), False), - ('exclude_mlp', LoRAScope(True, True, False), False), - ('messing_up', LoRAScope(), True)]) + @pytest.mark.parametrize( + "scope_ref_assert", + [ + ("none", LoRAScope(False, False, False), False), + ("all", LoRAScope(True, True, True), False), + ("qkv_proj", LoRAScope(True, False, False), False), + ("output_proj", LoRAScope(False, True, False), False), + ("mlp", LoRAScope(False, False, True), False), + ("exclude_qkv_proj", LoRAScope(False, True, True), False), + ("exclude_output_proj", LoRAScope(True, False, True), False), + ("exclude_mlp", LoRAScope(True, True, False), False), + ("messing_up", LoRAScope(), True), + ], + ) def test_lora_scope_generator(self, scope_ref_assert): scope, reference, need_assert = scope_ref_assert try: diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 3bb7520ca1a995b5c26f0985950e1e4ce89c4ede..3390c36426f13f843953bafc3f007ebe05a4f11f 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -24,7 +24,7 @@ from transformer_engine.jax.attention import ( QKVLayout, fused_attn_qkvpacked, fused_attn_kvpacked, - fused_attn + fused_attn, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper @@ -33,7 +33,7 @@ from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend from utils import assert_allclose -@pytest.fixture(autouse=True, scope='module') +@pytest.fixture(autouse=True, scope="module") def init(): """ WAR for CUDA uninitialize error @@ -43,10 +43,18 @@ def init(): yield -def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike, - bias: ArrayLike, mask: ArrayLike, deterministic: bool, - scale_factor: float, dropout_rate: float, dropout_rng: ArrayLike, - dtype: DTypeLike) -> Array: +def 
general_dot_product_attention( + query: ArrayLike, + key: ArrayLike, + value: ArrayLike, + bias: ArrayLike, + mask: ArrayLike, + deterministic: bool, + scale_factor: float, + dropout_rate: float, + dropout_rng: ArrayLike, + dtype: DTypeLike, +) -> Array: """ Similar to flax.linen.dot_product_attention but with GQA support """ @@ -59,7 +67,7 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array num_groups = h_q // h_kv grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d)) # logits with shape (b, h_kv, num_groups, s_q, s_kv) - logits = scale_factor * jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key) + logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key) if bias is not None: # reshape logits without groups @@ -76,13 +84,13 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array softmax_out = jax.nn.softmax(logits).astype(dtype) - if not deterministic and dropout_rate > 0.: + if not deterministic and dropout_rate > 0.0: keep_prob = 1.0 - dropout_rate keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape) multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype) softmax_out = softmax_out * multiplier - context = jnp.einsum('...hgqk,...khd->...qhgd', softmax_out, value) + context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value) context = jnp.reshape(context, query.shape) return context @@ -105,6 +113,7 @@ def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array: inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0) return combine_masks(inv_causal_mask, inv_padding_mask) + def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array: """ Create attention mask based on mask type. 
A `True` value in the mask means @@ -118,23 +127,26 @@ def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskT mask = jnp.logical_not(inv_mask) return mask + def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs): """ JAX native dot product attention implementation """ - attn_mask_type = kwargs['attn_mask_type'] + attn_mask_type = kwargs["attn_mask_type"] mask = make_mask(q_token, kv_token, attn_mask_type) - output = general_dot_product_attention(query, - key, - value, - bias=bias, - mask=mask, - deterministic=not kwargs['is_training'], - scale_factor=kwargs['scaling_factor'], - dropout_rate=kwargs['dropout_probability'], - dropout_rng=dropout_rng, - dtype=jnp.float32) + output = general_dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + deterministic=not kwargs["is_training"], + scale_factor=kwargs["scaling_factor"], + dropout_rate=kwargs["dropout_probability"], + dropout_rng=dropout_rng, + dtype=jnp.float32, + ) return output.astype(query.dtype) @@ -142,10 +154,10 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng """ TE customcall dot product attention implementation """ - attn_mask_type = kwargs['attn_mask_type'] + attn_mask_type = kwargs["attn_mask_type"] mask = make_mask(q_token, kv_token, attn_mask_type) - qkv_layout = kwargs.pop('qkv_layout') + qkv_layout = kwargs.pop("qkv_layout") match qkv_layout: case QKVLayout.BS3HD: query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value]) @@ -154,11 +166,13 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng case QKVLayout.BSHD_BS2HD: key, value = map(partial(jnp.expand_dims, axis=-3), [key, value]) kv = jnp.concatenate((key, value), axis=-3) - return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, - **kwargs).astype(query.dtype) + return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, **kwargs).astype( + query.dtype + ) case QKVLayout.BSHD_BSHD_BSHD: - return fused_attn(query, key, value, bias, mask, dropout_rng, - **kwargs).astype(query.dtype) + return fused_attn(query, key, value, bias, mask, dropout_rng, **kwargs).astype( + query.dtype + ) class BiasShape(Enum): @@ -166,10 +180,10 @@ class BiasShape(Enum): Enum class to represent the different bias shapes used in the fused attention. 
""" - BIAS_1HSS = '1HSS' - BIAS_B1SS = 'B1SS' - BIAS_BHSS = 'BHSS' - BIAS_11SS = '11SS' + BIAS_1HSS = "1HSS" + BIAS_B1SS = "B1SS" + BIAS_BHSS = "BHSS" + BIAS_11SS = "11SS" @dataclass @@ -177,6 +191,7 @@ class FusedAttnRunner: """ Fused attention runner """ + batch_size: int max_seqlen_q: int max_seqlen_kv: int @@ -198,21 +213,33 @@ class FusedAttnRunner: if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv: pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.") - self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value, - self.attn_bias_type.value, self.attn_mask_type.value, - self.dropout_prob, self.num_heads_q, self.num_heads_kv, - self.max_seqlen_q, self.max_seqlen_kv, - self.head_dim).get_fused_attn_backend() + self.backend = FusedAttnHelper( + self.dtype, + self.dtype, + self.qkv_layout.value, + self.attn_bias_type.value, + self.attn_mask_type.value, + self.dropout_prob, + self.num_heads_q, + self.num_heads_kv, + self.max_seqlen_q, + self.max_seqlen_kv, + self.head_dim, + ).get_fused_attn_backend() if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: pytest.skip("Unsupported inputs combination or device compute capability.") if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS: if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: - pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for " - "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.") + pytest.skip( + "B1SS, BHSS and 11SS bias shapes are only supported for " + "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK." + ) elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for " - "the F16_arbitrary_seqlen backend.") + pytest.skip( + "B1SS, BHSS and 11SS bias shapes are only supported for " + "the F16_arbitrary_seqlen backend." + ) def _setup_inputs(self): self._check_configs() @@ -235,24 +262,25 @@ class FusedAttnRunner: else: pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!") - self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.) - self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.) - self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.) + self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0) + self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0) + self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0) if self.attn_bias_type != AttnBiasType.NO_BIAS: if self.bias_shape == BiasShape.BIAS_1HSS: - self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.) + self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0) else: # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for # an arbitrary mask where (True/False -> 0/-Inf) - cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15. + cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0) self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype) max_id = min(self.max_seqlen_q, self.max_seqlen_kv) - seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences + seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist() for i in range(1, len(seq_id)): - self.bias = \ - self.bias.at[:, :, seq_id[i-1]:seq_id[i], seq_id[i-1]:seq_id[i]].set(0.) 
+ self.bias = self.bias.at[ + :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i] + ].set(0.0) else: self.bias = None @@ -271,7 +299,7 @@ class FusedAttnRunner: self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio) self.dropout_rng = dropout_key if self.dropout_prob > 0 else None - self.scaling_factor = 1. / sqrt(self.head_dim) + self.scaling_factor = 1.0 / sqrt(self.head_dim) def test_forward(self): """ @@ -281,19 +309,19 @@ class FusedAttnRunner: args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng] kwargs = { - 'attn_bias_type': self.attn_bias_type, - 'attn_mask_type': self.attn_mask_type, - 'scaling_factor': self.scaling_factor, - 'dropout_probability': self.dropout_prob, - 'is_training': self.is_training, - 'qkv_layout': self.qkv_layout, + "attn_bias_type": self.attn_bias_type, + "attn_mask_type": self.attn_mask_type, + "scaling_factor": self.scaling_factor, + "dropout_probability": self.dropout_prob, + "is_training": self.is_training, + "qkv_layout": self.qkv_layout, } # Convert the outputs to float32 for the elementwise comparison primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32) reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32) - if self.is_training and self.dropout_prob > 0.: + if self.is_training and self.dropout_prob > 0.0: return primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1) @@ -322,12 +350,12 @@ class FusedAttnRunner: args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng] kwargs = { - 'attn_bias_type': self.attn_bias_type, - 'attn_mask_type': self.attn_mask_type, - 'scaling_factor': self.scaling_factor, - 'dropout_probability': self.dropout_prob, - 'is_training': self.is_training, - 'qkv_layout': self.qkv_layout, + "attn_bias_type": self.attn_bias_type, + "attn_mask_type": self.attn_mask_type, + "scaling_factor": self.scaling_factor, + "dropout_probability": self.dropout_prob, + "is_training": self.is_training, + "qkv_layout": self.qkv_layout, } # We can compute dBias only for the [1, h, s, s] layout @@ -336,12 +364,18 @@ class FusedAttnRunner: # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation jitted_primitive = jit( value_and_grad( - lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args, - **kwargs), arg_nums)) + lambda q, k, v, bias, *args: grad_func( + customcall_fused_dpa, q, k, v, bias, *args, **kwargs + ), + arg_nums, + ) + ) jitted_reference = jit( value_and_grad( lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs), - arg_nums)) + arg_nums, + ) + ) primitive_out, primitive_dgrad = jitted_primitive(*args) reference_out, reference_dgrad = jitted_reference(*args) @@ -350,9 +384,9 @@ class FusedAttnRunner: if self.dropout_prob > 0.0: return - assert_allclose(primitive_out.astype(jnp.float32), - reference_out.astype(jnp.float32), - dtype=self.dtype) + assert_allclose( + primitive_out.astype(jnp.float32), reference_out.astype(jnp.float32), dtype=self.dtype + ) def check_dqkv(primitive, reference, valid_len): primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1) @@ -374,81 +408,158 @@ class FusedAttnRunner: primitive_dbias = jnp.float32(primitive_dgrad[3]) reference_dbias = jnp.float32(reference_dgrad[3]) - assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:], - jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, - self.valid_len_kv:]), - 
dtype=self.dtype) + assert_allclose( + primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :], + jnp.zeros_like(primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :]), + dtype=self.dtype, + ) # dbias padded part - assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:], - reference_dbias[..., self.valid_len_q:, self.valid_len_kv:], - dtype=self.dtype) + assert_allclose( + primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :], + reference_dbias[..., self.valid_len_q :, self.valid_len_kv :], + dtype=self.dtype, + ) # dbias valid part - assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv], - reference_dbias[..., :self.valid_len_q, :self.valid_len_kv], - dtype=self.dtype) - - -@pytest.mark.parametrize('attn_bias_type, bias_shape', [ - pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'), -]) -@pytest.mark.parametrize('attn_mask_type', [ - pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'), - pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'), - pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'), - pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'), -]) -@pytest.mark.parametrize('qkv_layout', [ - pytest.param(QKVLayout.BS3HD, id='QKV_PACKED'), - pytest.param(QKVLayout.BSHD_BS2HD, id='KV_PACKED'), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='SEPARATE'), -]) -@pytest.mark.parametrize('dtype', [ - pytest.param(jnp.bfloat16, id="BF16"), - pytest.param(jnp.float16, id="FP16"), -]) -@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [ - pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'), - pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'), - pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'), - pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'), - pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'), - pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'), -]) -@pytest.mark.parametrize('dropout_prob', [ - pytest.param(0.0, id="DROP_0.0"), - pytest.param(0.1, id="DROP_0.1"), -]) + assert_allclose( + primitive_dbias[..., : self.valid_len_q, : self.valid_len_kv], + reference_dbias[..., : self.valid_len_q, : self.valid_len_kv], + dtype=self.dtype, + ) + + +@pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"), + ], +) +@pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), + pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"), + ], +) +@pytest.mark.parametrize( + "qkv_layout", + [ + pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"), + 
pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), + ], +) +@pytest.mark.parametrize( + "dtype", + [ + pytest.param(jnp.bfloat16, id="BF16"), + pytest.param(jnp.float16, id="FP16"), + ], +) +@pytest.mark.parametrize( + "b, s_q, s_kv, h_q, h_kv, d", + [ + pytest.param(32, 128, 128, 16, 16, 64, id="32-128-128-16-16-64-SELF"), + pytest.param(4, 2048, 2048, 12, 12, 64, id="4-2048-2048-12-12-64-SELF"), + pytest.param(32, 512, 128, 16, 16, 64, id="32-512-128-16-16-64-CROSS"), + pytest.param(4, 2048, 1024, 12, 12, 64, id="4-2048-1048-12-12-64-CROSS"), + pytest.param(32, 128, 128, 16, 8, 64, id="32-128-128-16-8-64-GQA"), + pytest.param(4, 2048, 2048, 12, 6, 64, id="4-2048-2048-12-6-64-GQA"), + ], +) +@pytest.mark.parametrize( + "dropout_prob", + [ + pytest.param(0.0, id="DROP_0.0"), + pytest.param(0.1, id="DROP_0.1"), + ], +) class TestFusedAttn: """ Fused attention tester """ @staticmethod - @pytest.mark.parametrize('is_training', [ - pytest.param(True, id='TRAINING'), - pytest.param(False, id='INFERENCE'), - ]) - def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob, - dtype, is_training, qkv_layout, bias_shape): + @pytest.mark.parametrize( + "is_training", + [ + pytest.param(True, id="TRAINING"), + pytest.param(False, id="INFERENCE"), + ], + ) + def test_forward( + b, + s_q, + s_kv, + h_q, + h_kv, + d, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + qkv_layout, + bias_shape, + ): """ Test forward with parameterized configs """ - runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, - dropout_prob, dtype, is_training, qkv_layout, bias_shape) + runner = FusedAttnRunner( + b, + s_q, + s_kv, + h_q, + h_kv, + d, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + qkv_layout, + bias_shape, + ) runner.test_forward() @staticmethod - def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob, - dtype, qkv_layout, bias_shape): + def test_backward( + b, + s_q, + s_kv, + h_q, + h_kv, + d, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + qkv_layout, + bias_shape, + ): """ Test backward with parameterized configs """ - runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, - dropout_prob, dtype, True, qkv_layout, bias_shape) + runner = FusedAttnRunner( + b, + s_q, + s_kv, + h_q, + h_kv, + d, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + True, + qkv_layout, + bias_shape, + ) runner.test_backward() diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index 90f558744368281c5ffa6e87f1de054e951d301f..3a0add0a38cfdaf4736349963681dec8d4163b79 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -27,21 +27,28 @@ class TestFP8Helper(unittest.TestCase): fp8_format = FP8Format.E4M3 amax_history_len = 10 - FP8Helper.initialize(margin=margin, - fp8_format=fp8_format, - amax_history_len=amax_history_len) + FP8Helper.initialize( + margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len + ) self.assertEqual( - FP8Helper.MARGIN, margin, f"FP8Helper.MARGIN initialization failed, should be {margin}" - f" but got {FP8Helper.MARGIN}.") + FP8Helper.MARGIN, + margin, + f"FP8Helper.MARGIN initialization failed, should be {margin}" + f" but got {FP8Helper.MARGIN}.", + ) self.assertEqual( - FP8Helper.FP8_FORMAT, fp8_format, + FP8Helper.FP8_FORMAT, + fp8_format, f"FP8Helper.FP8_FORMAT initialization failed, should be 
{fp8_format}" - f" but got {FP8Helper.FP8_FORMAT}.") + f" but got {FP8Helper.FP8_FORMAT}.", + ) self.assertEqual( - FP8Helper.AMAX_HISTORY_LEN, amax_history_len, + FP8Helper.AMAX_HISTORY_LEN, + amax_history_len, f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}" - f" but got {FP8Helper.AMAX_HISTORY_LEN}.") + f" but got {FP8Helper.AMAX_HISTORY_LEN}.", + ) FP8Helper.finalize() @@ -77,7 +84,7 @@ class TestFP8Functions(unittest.TestCase): @unittest.skipIf(not is_fp8_supported, reason=reason) def test_fp8_autocast(self): - FP8Helper.finalize() # Ensure the testing not affect by previous tests. + FP8Helper.finalize() # Ensure the testing not affect by previous tests. self._check_defult_state() with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()): @@ -102,21 +109,21 @@ class TestFP8Functions(unittest.TestCase): @unittest.skipIf(not is_fp8_supported, reason=reason) def test_fp8_autocast_with_sharding_resource(self): - FP8Helper.finalize() # Ensure the testing not affect by previous tests. + FP8Helper.finalize() # Ensure the testing not affect by previous tests. self._check_defult_state() ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) mesh_s = ( (MeshResource(None, None)), - (MeshResource('dp', None)), - (MeshResource(None, 'tp')), - (MeshResource('dp', 'tp')), + (MeshResource("dp", None)), + (MeshResource(None, "tp")), + (MeshResource("dp", "tp")), ) # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme mesh_shape = (1, 1) devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape) - with jax.sharding.Mesh(devices, ('dp', 'tp')): + with jax.sharding.Mesh(devices, ("dp", "tp")): for sr in mesh_s: with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr): self.assertTrue(FP8Helper.is_fp8_enabled()) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index b8f52c91a2d86fe7f16e8a375a0fd4efdb532dc5..fa04382d598c6054f07d39075658ab5d3f70faa0 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -22,7 +22,7 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available is_fp8_supported, reason = is_fp8_available() -@pytest.fixture(autouse=True, scope='function') +@pytest.fixture(autouse=True, scope="function") def enable_fused_attn(): """Enable fused attention""" os.environ["NVTE_FUSED_ATTN"] = "1" @@ -30,9 +30,9 @@ def enable_fused_attn(): del os.environ["NVTE_FUSED_ATTN"] -DATA_SHAPE = [ # (batch, seqlen, emb_dim) - pytest.param((32, 128, 1024), id='32-128-1024'), - pytest.param((32, 512, 1024), id='32-512-1024'), +DATA_SHAPE = [ # (batch, seqlen, emb_dim) + pytest.param((32, 128, 1024), id="32-128-1024"), + pytest.param((32, 512, 1024), id="32-512-1024"), ] DTYPE = [jnp.float32, jnp.bfloat16] FP8_FORMATS = [Format.E4M3, Format.HYBRID] @@ -69,123 +69,138 @@ BASE_ATTRS = { _KEY_OF_ATTENTION_DROPOUT: 0, _KEY_OF_INTERMEDIATE_DROPOUT: 0, _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal", - _KEY_OF_LAYERNORM_TYPE: 'layernorm', + _KEY_OF_LAYERNORM_TYPE: "layernorm", } -ATTRS = [{}, { - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', -}, { - _KEY_OF_ZERO_CENTERED_GAMMA: True, - _KEY_OF_LAYERNORM_EPS: 1e-2, -}, { - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_RESIDUAL_POST_LAYERNORM: True -}, { - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_OUTPUT_LAYERNORM: True -}, { - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_RESIDUAL_POST_LAYERNORM: True, - _KEY_OF_OUTPUT_LAYERNORM: True -}, { - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_DROP_PATH: 0.1 -}, { - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - 
_KEY_OF_FUSE_QKV_PARAMS: False -}, { - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'), -}, { - _KEY_OF_SCALE_ATTN_LOGITS: True, - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_HIDDEN_DROPOUT: 0.8, - _KEY_OF_INTERMEDIATE_DROPOUT: 0.5, - _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'), - _KEY_OF_USE_BIAS: True, -}, { - _KEY_OF_TRANSPOSE_BS: False, - _KEY_OF_SCALE_ATTN_LOGITS: True, - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'), -}, { - _KEY_OF_NUM_HEADS: 8, - _KEY_OF_NUM_GQA_GROUPS: 4, - _KEY_OF_TRANSPOSE_BS: False, - _KEY_OF_SCALE_ATTN_LOGITS: True, - _KEY_OF_MLP_ACTIVATIONS: ('gelu',), - _KEY_OF_USE_BIAS: True, -}, { - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), -}, { - _KEY_OF_SCALE_ATTN_LOGITS: True, - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_HIDDEN_DROPOUT: 0.8, - _KEY_OF_INTERMEDIATE_DROPOUT: 0.5, - _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), - _KEY_OF_USE_BIAS: True, -}, { - _KEY_OF_TRANSPOSE_BS: False, - _KEY_OF_SCALE_ATTN_LOGITS: True, - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), -}, { - _KEY_OF_NUM_HEADS: 8, - _KEY_OF_NUM_GQA_GROUPS: 4, - _KEY_OF_TRANSPOSE_BS: False, - _KEY_OF_SCALE_ATTN_LOGITS: True, - _KEY_OF_LAYERNORM_TYPE: 'layernorm', - _KEY_OF_MLP_ACTIVATIONS: (('silu',)), - _KEY_OF_USE_BIAS: True, -}, { - _KEY_OF_TRANSPOSE_BS: False, - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_NUM_GQA_GROUPS: 1, - _KEY_OF_ENABLE_ROPE: True, - _KEY_OF_ROPE_GROUP_METHOD: "consecutive", - _KEY_OF_FLOAT32_ATTENTION_LOGITS: True, -}, { - _KEY_OF_TRANSPOSE_BS: True, - _KEY_OF_ENABLE_ROPE: True, - _KEY_OF_ROPE_GROUP_METHOD: "consecutive", - _KEY_OF_USE_BIAS: True, -}, { - _KEY_OF_TRANSPOSE_BS: False, - _KEY_OF_LAYERNORM_TYPE: 'layernorm', - _KEY_OF_NUM_GQA_GROUPS: 2, - _KEY_OF_ENABLE_ROPE: True, - _KEY_OF_ROPE_GROUP_METHOD: "alternate", - _KEY_OF_USE_BIAS: True, - _KEY_OF_FLOAT32_ATTENTION_LOGITS: True, -}, { - _KEY_OF_TRANSPOSE_BS: True, - _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_ENABLE_ROPE: True, - _KEY_OF_ROPE_GROUP_METHOD: "alternate", - _KEY_OF_USE_BIAS: True, -}, { - _KEY_OF_HIDDEN_DROPOUT: 0.3, - _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,), - _KEY_OF_INTERMEDIATE_DROPOUT: 0.5, - _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,), -}, { - _KEY_OF_SELF_ATTN_MASK_TYPE: "padding", - _KEY_OF_USE_BIAS: True, -}, { - _KEY_OF_RELATIVE_EMBEDDING: False, - _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias", -}, { - _KEY_OF_ATTENTION_DROPOUT: 0.3, -}, { - _KEY_OF_MLP_ACTIVATIONS: (('relu', 'relu')), -}] +ATTRS = [ + {}, + { + _KEY_OF_LAYERNORM_TYPE: "rmsnorm", + }, + { + _KEY_OF_ZERO_CENTERED_GAMMA: True, + _KEY_OF_LAYERNORM_EPS: 1e-2, + }, + {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True}, + {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True}, + { + _KEY_OF_LAYERNORM_TYPE: "rmsnorm", + _KEY_OF_RESIDUAL_POST_LAYERNORM: True, + _KEY_OF_OUTPUT_LAYERNORM: True, + }, + {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1}, + {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False}, + { + _KEY_OF_LAYERNORM_TYPE: "rmsnorm", + _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"), + }, + { + _KEY_OF_SCALE_ATTN_LOGITS: True, + _KEY_OF_LAYERNORM_TYPE: "rmsnorm", + _KEY_OF_HIDDEN_DROPOUT: 0.8, + _KEY_OF_INTERMEDIATE_DROPOUT: 0.5, + _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"), + _KEY_OF_USE_BIAS: True, + }, + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_SCALE_ATTN_LOGITS: True, + _KEY_OF_LAYERNORM_TYPE: "rmsnorm", + 
_KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"), + }, + { + _KEY_OF_NUM_HEADS: 8, + _KEY_OF_NUM_GQA_GROUPS: 4, + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_SCALE_ATTN_LOGITS: True, + _KEY_OF_MLP_ACTIVATIONS: ("gelu",), + _KEY_OF_USE_BIAS: True, + }, + { + _KEY_OF_LAYERNORM_TYPE: "rmsnorm", + _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")), + }, + { + _KEY_OF_SCALE_ATTN_LOGITS: True, + _KEY_OF_LAYERNORM_TYPE: "rmsnorm", + _KEY_OF_HIDDEN_DROPOUT: 0.8, + _KEY_OF_INTERMEDIATE_DROPOUT: 0.5, + _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")), + _KEY_OF_USE_BIAS: True, + }, + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_SCALE_ATTN_LOGITS: True, + _KEY_OF_LAYERNORM_TYPE: "rmsnorm", + _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")), + }, + { + _KEY_OF_NUM_HEADS: 8, + _KEY_OF_NUM_GQA_GROUPS: 4, + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_SCALE_ATTN_LOGITS: True, + _KEY_OF_LAYERNORM_TYPE: "layernorm", + _KEY_OF_MLP_ACTIVATIONS: (("silu",)), + _KEY_OF_USE_BIAS: True, + }, + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_LAYERNORM_TYPE: "rmsnorm", + _KEY_OF_NUM_GQA_GROUPS: 1, + _KEY_OF_ENABLE_ROPE: True, + _KEY_OF_ROPE_GROUP_METHOD: "consecutive", + _KEY_OF_FLOAT32_ATTENTION_LOGITS: True, + }, + { + _KEY_OF_TRANSPOSE_BS: True, + _KEY_OF_ENABLE_ROPE: True, + _KEY_OF_ROPE_GROUP_METHOD: "consecutive", + _KEY_OF_USE_BIAS: True, + }, + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_LAYERNORM_TYPE: "layernorm", + _KEY_OF_NUM_GQA_GROUPS: 2, + _KEY_OF_ENABLE_ROPE: True, + _KEY_OF_ROPE_GROUP_METHOD: "alternate", + _KEY_OF_USE_BIAS: True, + _KEY_OF_FLOAT32_ATTENTION_LOGITS: True, + }, + { + _KEY_OF_TRANSPOSE_BS: True, + _KEY_OF_LAYERNORM_TYPE: "rmsnorm", + _KEY_OF_ENABLE_ROPE: True, + _KEY_OF_ROPE_GROUP_METHOD: "alternate", + _KEY_OF_USE_BIAS: True, + }, + { + _KEY_OF_HIDDEN_DROPOUT: 0.3, + _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,), + _KEY_OF_INTERMEDIATE_DROPOUT: 0.5, + _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,), + }, + { + _KEY_OF_SELF_ATTN_MASK_TYPE: "padding", + _KEY_OF_USE_BIAS: True, + }, + { + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias", + }, + { + _KEY_OF_ATTENTION_DROPOUT: 0.3, + }, + { + _KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")), + }, +] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] class BaseRunner: """Base runner to define forward and backward tests""" + layer_type: TransformerLayerType = None reference_layer: flax.linen.Module = None transformations: Dict[str, str] = None @@ -194,24 +209,24 @@ class BaseRunner: self.attrs = attrs self._generate_test_rngs() # Disable fused attention for attention dropout because the different dropout impl - if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv('NVTE_FUSED_ATTN'): - os.environ['NVTE_FUSED_ATTN'] = "0" + if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv("NVTE_FUSED_ATTN"): + os.environ["NVTE_FUSED_ATTN"] = "0" def _generate_test_rngs(self): root_rng = jax.random.PRNGKey(0) params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3) - self.init_rng = {'params': params_rng, 'dropout': init_dropout_rng} - self.apply_rng = {'dropout': apply_dropout_rng} + self.init_rng = {"params": params_rng, "dropout": init_dropout_rng} + self.apply_rng = {"dropout": apply_dropout_rng} def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs): layer = layer_cls() variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs) - others, params = flax.core.pop(variables, 'params') + others, params = flax.core.pop(variables, "params") del variables return layer, params, others def _loss_fn(self, diff_xs, no_diff_xs, params, 
others, model): - variables = {'params': params, **others} + variables = {"params": params, **others} output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng) return jnp.mean(output, dtype=jnp.float32).astype(output.dtype) @@ -259,15 +274,18 @@ class BaseRunner: ) _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME) test_others = FP8Helper.update_collections( - {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others) + {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others + ) del tmp_grad, fp8_meta_grad grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False) - ref_out, (ref_dgrads, ref_wgrads) = grad_fn(inputs, ref_masks, ref_params, ref_others, - ref_layer) - test_out, (test_dgrads, test_wgrads) = grad_fn(inputs, test_masks, test_params, test_others, - test_layer) + ref_out, (ref_dgrads, ref_wgrads) = grad_fn( + inputs, ref_masks, ref_params, ref_others, ref_layer + ) + test_out, (test_dgrads, test_wgrads) = grad_fn( + inputs, test_masks, test_params, test_others, test_layer + ) assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol) @@ -278,19 +296,20 @@ class BaseRunner: class EncoderRunner(BaseRunner): """Encoder runner implementations""" + layer_type = TransformerLayerType.ENCODER reference_layer = RefEncoderLayer transformations = { - 'attention/qkv/scale': 'pre_attention_layer_norm/scale', - 'attention/qkv/ln_bias': 'pre_attention_layer_norm/ln_bias', - 'attention/query/scale': 'pre_attention_layer_norm/scale', - 'attention/query/ln_bias': 'pre_attention_layer_norm/ln_bias', - 'mlp/wi_kernel': 'mlp/wi/kernel', - 'mlp/wi_bias': 'mlp/wi/bias', - 'mlp/wo_kernel': 'mlp/wo/kernel', - 'mlp/wo_bias': 'mlp/wo/bias', - 'mlp/scale': 'pre_mlp_layer_norm/scale', - 'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias', + "attention/qkv/scale": "pre_attention_layer_norm/scale", + "attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias", + "attention/query/scale": "pre_attention_layer_norm/scale", + "attention/query/ln_bias": "pre_attention_layer_norm/ln_bias", + "mlp/wi_kernel": "mlp/wi/kernel", + "mlp/wi_bias": "mlp/wi/bias", + "mlp/wo_kernel": "mlp/wo/kernel", + "mlp/wo_bias": "mlp/wo/bias", + "mlp/scale": "pre_mlp_layer_norm/scale", + "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias", } def generate_inputs(self, data_shape, dtype): @@ -307,13 +326,13 @@ class EncoderRunner(BaseRunner): padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) - if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['casual', 'padding_causal']: + if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]: mask = causal_mask else: mask = padded_mask ref_masks = (1 - mask,) - test_masks = (None, mask) # The second arg of Transformer is encoded tokens. + test_masks = (None, mask)  # The second arg of Transformer is encoded tokens. 
return inputs, (ref_masks, test_masks) @@ -322,23 +341,24 @@ class DecoderRunner(BaseRunner): """ Decoder runner implementations """ + layer_type = TransformerLayerType.DECODER reference_layer = RefDecoderLayer transformations = { - 'encoder_decoder_attention/qkv/scale': 'pre_cross_attention_layer_norm/scale', - 'encoder_decoder_attention/qkv/ln_bias': 'pre_cross_attention_layer_norm/ln_bias', - 'encoder_decoder_attention/query/scale': 'pre_cross_attention_layer_norm/scale', - 'encoder_decoder_attention/query/ln_bias': 'pre_cross_attention_layer_norm/ln_bias', - 'self_attention/qkv/scale': 'pre_self_attention_layer_norm/scale', - 'self_attention/qkv/ln_bias': 'pre_self_attention_layer_norm/ln_bias', - 'self_attention/query/scale': 'pre_self_attention_layer_norm/scale', - 'self_attention/query/ln_bias': 'pre_self_attention_layer_norm/ln_bias', - 'mlp/wi_kernel': 'mlp/wi/kernel', - 'mlp/wi_bias': 'mlp/wi/bias', - 'mlp/wo_kernel': 'mlp/wo/kernel', - 'mlp/wo_bias': 'mlp/wo/bias', - 'mlp/scale': 'pre_mlp_layer_norm/scale', - 'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias', + "encoder_decoder_attention/qkv/scale": "pre_cross_attention_layer_norm/scale", + "encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias", + "encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale", + "encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias", + "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale", + "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias", + "self_attention/query/scale": "pre_self_attention_layer_norm/scale", + "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias", + "mlp/wi_kernel": "mlp/wi/kernel", + "mlp/wi_bias": "mlp/wi/bias", + "mlp/wo_kernel": "mlp/wo/kernel", + "mlp/wo_bias": "mlp/wo/bias", + "mlp/scale": "pre_mlp_layer_norm/scale", + "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias", } def generate_inputs(self, data_shape, dtype): @@ -352,12 +372,14 @@ class DecoderRunner(BaseRunner): data_rng = jax.random.PRNGKey(0) data_rng_0, data_rng_1 = jax.random.split(data_rng, 2) - inputs = (jax.random.normal(data_rng_0, data_shape, - dtype), jax.random.normal(data_rng_1, data_shape, dtype)) + inputs = ( + jax.random.normal(data_rng_0, data_shape, dtype), + jax.random.normal(data_rng_1, data_shape, dtype), + ) padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) - if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['casual', 'padding_causal']: + if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]: self_mask = causal_mask else: self_mask = padded_mask @@ -368,27 +390,28 @@ class DecoderRunner(BaseRunner): return inputs, (ref_masks, test_masks) -@pytest.mark.parametrize('data_shape', DATA_SHAPE) -@pytest.mark.parametrize('dtype', DTYPE) -@pytest.mark.parametrize('attrs', ATTRS) -class BaseTester(): +@pytest.mark.parametrize("data_shape", DATA_SHAPE) +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("attrs", ATTRS) +class BaseTester: """ Pytest interface to invoke the runner """ + runner = BaseRunner def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" - FP8Helper.finalize() # Ensure FP8 disabled. + FP8Helper.finalize()  # Ensure FP8 disabled. 
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5) def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" - FP8Helper.finalize() # Ensure FP8 disabled. + FP8Helper.finalize() # Ensure FP8 disabled. self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('fp8_format', FP8_FORMATS) + @pytest.mark.parametrize("fp8_format", FP8_FORMATS) def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format): """Test forward with fp8 enabled""" FP8Helper.initialize(fp8_format=fp8_format) @@ -396,7 +419,7 @@ class BaseTester(): FP8Helper.finalize() @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('fp8_format', FP8_FORMATS) + @pytest.mark.parametrize("fp8_format", FP8_FORMATS) def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format): """Test backward with fp8 enabled""" FP8Helper.initialize(fp8_format=fp8_format) @@ -408,6 +431,7 @@ class TestEncoderLayer(BaseTester): """ Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder) """ + runner = EncoderRunner @@ -415,4 +439,5 @@ class TestDecoderLayer(BaseTester): """ Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder) """ + runner = DecoderRunner diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index 453a8393ed5323fb8f631a393d81d8eecf574ab8..92a6c80028a2be38b8cfb5bb71ce86e9160276d5 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -37,13 +37,13 @@ from transformer_engine.jax.softmax import SoftmaxType is_fp8_supported, reason = is_fp8_available() -DATA_SHAPE = [(32, 128, 512), (32, 512, 512)] # (B, S, H) +DATA_SHAPE = [(32, 128, 512), (32, 512, 512)] # (B, S, H) DTYPE = [jnp.float32, jnp.bfloat16] ENABLE_FP8 = [False, True] FP8_FORMATS = [Format.E4M3, Format.HYBRID] -@pytest.fixture(autouse=True, scope='module') +@pytest.fixture(autouse=True, scope="module") def enable_fused_attn(): """ Enable fused attn for hopper+ arch. 
@@ -58,19 +58,16 @@ def enable_fused_attn(): def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): for key in ref_fd: - assert key in test_fd, \ - f"{key} not found in test dict {test_fd}" - assert isinstance(test_fd[key], type(ref_fd[key])), \ - f"The data type is not match between ref and test " \ - f" Dict on {key=}" + assert key in test_fd, f"{key} not found in test dict {test_fd}" + assert isinstance( + test_fd[key], type(ref_fd[key]) + ), f"The data type is not match between ref and test Dict on {key=}" if isinstance(ref_fd[key], Dict): compare_dict(ref_fd[key], test_fd[key], rtol, atol) else: - assert_allclose(ref_fd[key], - test_fd[key], - rtol=rtol, - atol=atol, - err_msg=f"{key=} is not close") + assert_allclose( + ref_fd[key], test_fd[key], rtol=rtol, atol=atol, err_msg=f"{key=} is not close" + ) class TestLayer: @@ -105,9 +102,10 @@ class TestLayer: lyr_name = self.get_layer_name() - if 'params' in flax_variables: - synced_praxis_variables['params'][lyr_name]['cld'] = \ - flax.core.unfreeze(flax_variables['params']) + if "params" in flax_variables: + synced_praxis_variables["params"][lyr_name]["cld"] = flax.core.unfreeze( + flax_variables["params"] + ) return synced_praxis_variables, flax_variables @@ -116,23 +114,19 @@ class TestLayer: lyr_name = self.get_layer_name() - if 'params' in synced_praxis_grads: - synced_praxis_grads['params'] = \ - synced_praxis_grads['params'][lyr_name]['cld'] + if "params" in synced_praxis_grads: + synced_praxis_grads["params"] = synced_praxis_grads["params"][lyr_name]["cld"] if FP8Helper.is_fp8_enabled(): - synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = \ - synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME][lyr_name]['cld'] + synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = synced_praxis_grads[ + FP8Helper.FP8_COLLECTION_NAME + ][lyr_name]["cld"] return synced_praxis_grads, flax.core.unfreeze(flax_wgrads) - def forward_backward_runner(self, - data_shape, - dtype, - praxis_p, - flax_cls, - rtol=1e-05, - atol=1e-08): + def forward_backward_runner( + self, data_shape, dtype, praxis_p, flax_cls, rtol=1e-05, atol=1e-08 + ): init_key = jax.random.PRNGKey(seed=1234) test_inputs = self.input_getter(data_shape, dtype) @@ -148,28 +142,33 @@ class TestLayer: if "params_axes" in flax_variables: flax_variables, _ = flax.core.pop(flax_variables, "params_axes") if FP8Helper.is_fp8_enabled(): - flax_variables, _ = flax.core.pop(flax_variables, - FP8Helper.FP8_COLLECTION_NAME + "_axes") + flax_variables, _ = flax.core.pop( + flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes" + ) praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables) iter_times = 5 if FP8Helper.is_fp8_enabled() else 1 for _ in range(iter_times): - praxis_loss, praxis_wgrads, praxis_dgrad = \ - TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs) - flax_loss, flax_wgrads, flax_dgrad = \ - TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs) + praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads( + praxis_layer, praxis_variables, *test_inputs + ) + flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads( + flax_layer, flax_variables, *test_inputs + ) if FP8Helper.is_fp8_enabled(): - praxis_wgrads.pop('params') + praxis_wgrads.pop("params") praxis_variables = update_collections(praxis_wgrads, praxis_variables) - flax_wgrads, _ = flax.core.pop(flax_wgrads, 'params') + flax_wgrads, _ = flax.core.pop(flax_wgrads, "params") flax_variables = update_collections(flax_wgrads, flax_variables) - 
praxis_loss, praxis_wgrads, praxis_dgrad = \ - TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs) - flax_loss, flax_wgrads, flax_dgrad = \ - TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs) + praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads( + praxis_layer, praxis_variables, *test_inputs + ) + flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads( + flax_layer, flax_variables, *test_inputs + ) assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol) assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol) @@ -179,18 +178,13 @@ class TestLayer: class LayerNormAttr: - LN_TYPE = 'layernorm_type' - ZERO_CEN = 'zero_centered_gamma' - ATTRS = [{ - LN_TYPE: "layernorm", - ZERO_CEN: False - }, { - LN_TYPE: "layernorm", - ZERO_CEN: True - }, { - LN_TYPE: "rmsnorm", - ZERO_CEN: False - }] + LN_TYPE = "layernorm_type" + ZERO_CEN = "zero_centered_gamma" + ATTRS = [ + {LN_TYPE: "layernorm", ZERO_CEN: False}, + {LN_TYPE: "layernorm", ZERO_CEN: True}, + {LN_TYPE: "rmsnorm", ZERO_CEN: False}, + ] class TestLayerNorm(TestLayer): @@ -200,7 +194,7 @@ class TestLayerNorm(TestLayer): return (jax.random.normal(data_key, shape, dtype),) def get_layer_name(self): - return 'layer_norm' + return "layer_norm" def generate_praxis_p_and_flax_cls(self, dtype, attrs): layernorm_type = attrs[LayerNormAttr.LN_TYPE] @@ -209,63 +203,59 @@ class TestLayerNorm(TestLayer): bias_init = WeightInit.Constant(0.0) transpose_batch_sequence = False - praxis_p = pax_fiddle.Config(LayerNorm, - name='layer_norm', - dtype=dtype, - layernorm_type=layernorm_type, - zero_centered_gamma=zero_centered_gamma, - scale_init=scale_init, - bias_init=bias_init, - transpose_batch_sequence=transpose_batch_sequence) - flax_cls = partial(flax_LayerNorm, - layernorm_type=layernorm_type, - zero_centered_gamma=zero_centered_gamma, - scale_init=scale_init, - bias_init=TransformerEngineBaseLayer.generate_params_init( - "ln_bias", bias_init), - dtype=dtype, - transpose_batch_sequence=transpose_batch_sequence) + praxis_p = pax_fiddle.Config( + LayerNorm, + name="layer_norm", + dtype=dtype, + layernorm_type=layernorm_type, + zero_centered_gamma=zero_centered_gamma, + scale_init=scale_init, + bias_init=bias_init, + transpose_batch_sequence=transpose_batch_sequence, + ) + flax_cls = partial( + flax_LayerNorm, + layernorm_type=layernorm_type, + zero_centered_gamma=zero_centered_gamma, + scale_init=scale_init, + bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", bias_init), + dtype=dtype, + transpose_batch_sequence=transpose_batch_sequence, + ) return praxis_p, flax_cls - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', LayerNormAttr.ATTRS) + @pytest.mark.parametrize("data_shape", DATA_SHAPE) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", LayerNormAttr.ATTRS) def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) class FusedSoftmaxAttr: - SCALE_FACTOR = 'scale_factor' - ST_TYPE = 'softmax_type' - ATTRS = [{ - SCALE_FACTOR: 0.0, - ST_TYPE: SoftmaxType.SCALED - }, { - SCALE_FACTOR: 0.0, - ST_TYPE: SoftmaxType.SCALED_MASKED - }, { - SCALE_FACTOR: 0.0, - ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED - }] + SCALE_FACTOR = "scale_factor" + ST_TYPE = "softmax_type" + ATTRS = [ + 
{SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED}, + {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_MASKED}, + {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED}, + ] class TestFusedSoftmax(TestLayer): def input_getter(self, shape, dtype): data_key = jax.random.PRNGKey(seed=1234) - return jax.random.normal(data_key, shape, dtype), \ - jnp.ones(shape, dtype=jnp.uint8) # Masks + return jax.random.normal(data_key, shape, dtype), jnp.ones(shape, dtype=jnp.uint8) # Masks def generate_praxis_p_and_flax_cls(self, dtype, attrs): scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR] softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE] - praxis_p = pax_fiddle.Config(FusedSoftmax, - name='fused_softmax', - scale_factor=scale_factor, - softmax_type=softmax_type) + praxis_p = pax_fiddle.Config( + FusedSoftmax, name="fused_softmax", scale_factor=scale_factor, softmax_type=softmax_type + ) flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type) return praxis_p, flax_cls @@ -276,34 +266,28 @@ class TestFusedSoftmax(TestLayer): def sync_wgrads(self, praxis_wgrads, flax_wgrads): return praxis_wgrads, flax_wgrads - @pytest.mark.parametrize('data_shape', [(32, 1, 128, 128), (32, 1, 512, 128)]) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', FusedSoftmaxAttr.ATTRS) + @pytest.mark.parametrize("data_shape", [(32, 1, 128, 128), (32, 1, 512, 128)]) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", FusedSoftmaxAttr.ATTRS) def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): - if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and \ - (data_shape[-2] != data_shape[-1]): - pass # Skip, due to not support + if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and ( + data_shape[-2] != data_shape[-1] + ): + pass # Skip, due to not support else: praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) class LinearAttr: - FEATURE = 'features' - USE_BIAS = 'use_bias' - ATTRS = [{ - FEATURE: 512, - USE_BIAS: False - }, { - FEATURE: 512, - USE_BIAS: True - }, { - FEATURE: 1024, - USE_BIAS: False - }, { - FEATURE: 1024, - USE_BIAS: True - }] + FEATURE = "features" + USE_BIAS = "use_bias" + ATTRS = [ + {FEATURE: 512, USE_BIAS: False}, + {FEATURE: 512, USE_BIAS: True}, + {FEATURE: 1024, USE_BIAS: False}, + {FEATURE: 1024, USE_BIAS: True}, + ] class TestLinear(TestLayer): @@ -313,7 +297,7 @@ class TestLinear(TestLayer): return (jax.random.normal(data_key, shape, dtype),) def get_layer_name(self): - return 'linear' + return "linear" def generate_praxis_p_and_flax_cls(self, dtype, attrs): out_features = attrs[LinearAttr.FEATURE] @@ -323,15 +307,17 @@ class TestLinear(TestLayer): axis = -1 transpose_batch_sequence = False - praxis_p = pax_fiddle.Config(Linear, - name='linear', - dtype=dtype, - out_features=out_features, - params_init=kernel_init, - use_bias=use_bias, - bias_init=bias_init, - axis=axis, - transpose_batch_sequence=transpose_batch_sequence) + praxis_p = pax_fiddle.Config( + Linear, + name="linear", + dtype=dtype, + out_features=out_features, + params_init=kernel_init, + use_bias=use_bias, + bias_init=bias_init, + axis=axis, + transpose_batch_sequence=transpose_batch_sequence, + ) flax_cls = partial( DenseGeneral, features=out_features, @@ -340,29 +326,26 @@ class TestLinear(TestLayer): 
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init), axis=axis, dtype=dtype, - transpose_batch_sequence=transpose_batch_sequence) + transpose_batch_sequence=transpose_batch_sequence, + ) return praxis_p, flax_cls - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', LinearAttr.ATTRS) + @pytest.mark.parametrize("data_shape", DATA_SHAPE) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", LinearAttr.ATTRS) def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', LinearAttr.ATTRS) - @pytest.mark.parametrize('fp8_format', FP8_FORMATS) - def test_forward_backward_fp8(self, - data_shape, - dtype, - attrs, - fp8_format, - rtol=1e-05, - atol=1e-08): + @pytest.mark.parametrize("data_shape", DATA_SHAPE) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", LinearAttr.ATTRS) + @pytest.mark.parametrize("fp8_format", FP8_FORMATS) + def test_forward_backward_fp8( + self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08 + ): ds = DelayedScaling(fp8_format=fp8_format) with fp8_autocast(enabled=True, fp8_recipe=ds): @@ -371,54 +354,20 @@ class TestLinear(TestLayer): class LayerNormLinearAttr: - FEATURE = 'features' - USE_BIAS = 'use_bias' - ENABLE_LN = 'enable_layernorm' - LN_TYPE = 'layernorm_type' - ZERO_CEN = 'zero_centered_gamma' - ATTRS = [{ - FEATURE: 512, - USE_BIAS: True, - ENABLE_LN: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False - }, { - FEATURE: 512, - USE_BIAS: True, - ENABLE_LN: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False - }, { - FEATURE: 512, - USE_BIAS: True, - ENABLE_LN: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True - }, { - FEATURE: 512, - USE_BIAS: True, - ENABLE_LN: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True - }, { - FEATURE: 512, - USE_BIAS: True, - ENABLE_LN: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False - }, { - FEATURE: 512, - USE_BIAS: True, - ENABLE_LN: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False - }, { - FEATURE: 512, - USE_BIAS: True, - ENABLE_LN: False, - LN_TYPE: 'layernorm', - ZERO_CEN: False - }] + FEATURE = "features" + USE_BIAS = "use_bias" + ENABLE_LN = "enable_layernorm" + LN_TYPE = "layernorm_type" + ZERO_CEN = "zero_centered_gamma" + ATTRS = [ + {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False}, + {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False}, + {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True}, + {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True}, + {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False}, + {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False}, + {FEATURE: 512, USE_BIAS: True, ENABLE_LN: False, LN_TYPE: "layernorm", ZERO_CEN: False}, + ] class TestLayerNormLinear(TestLayer): @@ -428,7 +377,7 @@ class TestLayerNormLinear(TestLayer): return (jax.random.normal(data_key, shape, dtype),) def get_layer_name(self): - return 'ln_linear' + return "ln_linear" def generate_praxis_p_and_flax_cls(self, dtype, attrs): 
out_features = attrs[LayerNormLinearAttr.FEATURE] @@ -441,18 +390,20 @@ class TestLayerNormLinear(TestLayer): axis = -1 transpose_batch_sequence = False - praxis_p = pax_fiddle.Config(LayerNormLinear, - name='ln_linear', - dtype=dtype, - out_features=out_features, - enable_layernorm=enable_layernorm, - layernorm_type=layernorm_type, - zero_centered_gamma=zero_centered_gamma, - params_init=kernel_init, - use_bias=use_bias, - bias_init=bias_init, - axis=axis, - transpose_batch_sequence=transpose_batch_sequence) + praxis_p = pax_fiddle.Config( + LayerNormLinear, + name="ln_linear", + dtype=dtype, + out_features=out_features, + enable_layernorm=enable_layernorm, + layernorm_type=layernorm_type, + zero_centered_gamma=zero_centered_gamma, + params_init=kernel_init, + use_bias=use_bias, + bias_init=bias_init, + axis=axis, + transpose_batch_sequence=transpose_batch_sequence, + ) flax_cls = partial( LayerNormDenseGeneral, features=out_features, @@ -464,29 +415,26 @@ class TestLayerNormLinear(TestLayer): bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init), axis=axis, dtype=dtype, - transpose_batch_sequence=transpose_batch_sequence) + transpose_batch_sequence=transpose_batch_sequence, + ) return praxis_p, flax_cls - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS) + @pytest.mark.parametrize("data_shape", DATA_SHAPE) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS) def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS) - @pytest.mark.parametrize('fp8_format', FP8_FORMATS) - def test_forward_backward_fp8(self, - data_shape, - dtype, - attrs, - fp8_format, - rtol=1e-05, - atol=1e-08): + @pytest.mark.parametrize("data_shape", DATA_SHAPE) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS) + @pytest.mark.parametrize("fp8_format", FP8_FORMATS) + def test_forward_backward_fp8( + self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08 + ): ds = DelayedScaling(fp8_format=fp8_format) with fp8_autocast(enabled=True, fp8_recipe=ds): @@ -495,62 +443,70 @@ class TestLayerNormLinear(TestLayer): class LayerNormMLPAttr: - INTERMEDIATE_DIM = 'intermediate_dim' - USE_BIAS = 'use_bias' - ENABLE_LN = 'enable_layernorm' - LN_TYPE = 'layernorm_type' - ZERO_CEN = 'zero_centered_gamma' - ACTIVATION = 'activations' - ATTRS = [{ - INTERMEDIATE_DIM: 2048, - USE_BIAS: True, - ENABLE_LN: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ACTIVATION: ('relu',) - }, { - INTERMEDIATE_DIM: 2048, - USE_BIAS: True, - ENABLE_LN: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True, - ACTIVATION: ('relu',) - }, { - INTERMEDIATE_DIM: 2048, - USE_BIAS: True, - ENABLE_LN: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: ('relu',) - }, { - INTERMEDIATE_DIM: 2048, - USE_BIAS: True, - ENABLE_LN: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: ('gelu', 'linear') - }, { - INTERMEDIATE_DIM: 2048, - USE_BIAS: False, - ENABLE_LN: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: 
('gelu', 'linear') - }, { - INTERMEDIATE_DIM: 2048, - USE_BIAS: True, - ENABLE_LN: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: ('silu', 'linear') - }, { - INTERMEDIATE_DIM: 2048, - USE_BIAS: False, - ENABLE_LN: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: ('silu', 'linear') - }] + INTERMEDIATE_DIM = "intermediate_dim" + USE_BIAS = "use_bias" + ENABLE_LN = "enable_layernorm" + LN_TYPE = "layernorm_type" + ZERO_CEN = "zero_centered_gamma" + ACTIVATION = "activations" + ATTRS = [ + { + INTERMEDIATE_DIM: 2048, + USE_BIAS: True, + ENABLE_LN: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("relu",), + }, + { + INTERMEDIATE_DIM: 2048, + USE_BIAS: True, + ENABLE_LN: True, + LN_TYPE: "layernorm", + ZERO_CEN: True, + ACTIVATION: ("relu",), + }, + { + INTERMEDIATE_DIM: 2048, + USE_BIAS: True, + ENABLE_LN: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("relu",), + }, + { + INTERMEDIATE_DIM: 2048, + USE_BIAS: True, + ENABLE_LN: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("gelu", "linear"), + }, + { + INTERMEDIATE_DIM: 2048, + USE_BIAS: False, + ENABLE_LN: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("gelu", "linear"), + }, + { + INTERMEDIATE_DIM: 2048, + USE_BIAS: True, + ENABLE_LN: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("silu", "linear"), + }, + { + INTERMEDIATE_DIM: 2048, + USE_BIAS: False, + ENABLE_LN: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("silu", "linear"), + }, + ] class TestLayerNormMLP(TestLayer): @@ -560,7 +516,7 @@ class TestLayerNormMLP(TestLayer): return (jax.random.normal(data_key, shape, dtype),) def get_layer_name(self): - return 'ln_mlp' + return "ln_mlp" def generate_praxis_p_and_flax_cls(self, dtype, attrs): intermediate_dim = attrs[LayerNormMLPAttr.INTERMEDIATE_DIM] @@ -574,20 +530,22 @@ class TestLayerNormMLP(TestLayer): axis = -1 transpose_batch_sequence = False - praxis_p = pax_fiddle.Config(LayerNormMLP, - name='ln_mlp', - dtype=dtype, - intermediate_dim=intermediate_dim, - enable_layernorm=enable_layernorm, - layernorm_type=layernorm_type, - zero_centered_gamma=zero_centered_gamma, - params_init=kernel_init, - use_bias=use_bias, - bias_init=bias_init, - activations=activations, - intermediate_dropout_rate=0.0, - axis=axis, - transpose_batch_sequence=transpose_batch_sequence) + praxis_p = pax_fiddle.Config( + LayerNormMLP, + name="ln_mlp", + dtype=dtype, + intermediate_dim=intermediate_dim, + enable_layernorm=enable_layernorm, + layernorm_type=layernorm_type, + zero_centered_gamma=zero_centered_gamma, + params_init=kernel_init, + use_bias=use_bias, + bias_init=bias_init, + activations=activations, + intermediate_dropout_rate=0.0, + axis=axis, + transpose_batch_sequence=transpose_batch_sequence, + ) flax_cls = partial( flax_LayerNormMLP, intermediate_dim=intermediate_dim, @@ -601,29 +559,26 @@ class TestLayerNormMLP(TestLayer): intermediate_dropout_rate=0.0, axis=axis, dtype=dtype, - transpose_batch_sequence=transpose_batch_sequence) + transpose_batch_sequence=transpose_batch_sequence, + ) return praxis_p, flax_cls - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS) + @pytest.mark.parametrize("data_shape", DATA_SHAPE) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS) def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): praxis_p, flax_cls = 
self.generate_praxis_p_and_flax_cls(dtype, attrs) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS) - @pytest.mark.parametrize('fp8_format', FP8_FORMATS) - def test_forward_backward_fp8(self, - data_shape, - dtype, - attrs, - fp8_format, - rtol=1e-05, - atol=1e-08): + @pytest.mark.parametrize("data_shape", DATA_SHAPE) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS) + @pytest.mark.parametrize("fp8_format", FP8_FORMATS) + def test_forward_backward_fp8( + self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08 + ): ds = DelayedScaling(fp8_format=fp8_format) with fp8_autocast(enabled=True, fp8_recipe=ds): @@ -634,35 +589,40 @@ class TestLayerNormMLP(TestLayer): class TestRelativePositionBias(TestLayer): def get_layer_name(self): - return 'relative_position_bias' + return "relative_position_bias" def generate_praxis_p_and_flax_cls(self, dtype, attrs): num_buckets = 32 max_distance = 128 num_attention_heads = 64 - rb_stddev = (num_attention_heads * num_buckets)**-0.5 + rb_stddev = (num_attention_heads * num_buckets) ** -0.5 embedding_init = WeightInit.Gaussian(rb_stddev) - praxis_p = pax_fiddle.Config(RelativePositionBiases, - name='relative_position_bias', - dtype=dtype, - num_buckets=num_buckets, - max_distance=max_distance, - num_attention_heads=num_attention_heads, - embedding_init=embedding_init) - flax_cls = partial(flax_RelativePositionBiases, - num_buckets=num_buckets, - max_distance=max_distance, - num_attention_heads=num_attention_heads, - embedding_init=TransformerEngineBaseLayer.generate_params_init( - "rel_embedding", embedding_init), - dtype=dtype) + praxis_p = pax_fiddle.Config( + RelativePositionBiases, + name="relative_position_bias", + dtype=dtype, + num_buckets=num_buckets, + max_distance=max_distance, + num_attention_heads=num_attention_heads, + embedding_init=embedding_init, + ) + flax_cls = partial( + flax_RelativePositionBiases, + num_buckets=num_buckets, + max_distance=max_distance, + num_attention_heads=num_attention_heads, + embedding_init=TransformerEngineBaseLayer.generate_params_init( + "rel_embedding", embedding_init + ), + dtype=dtype, + ) return praxis_p, flax_cls - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', [{}]) + @pytest.mark.parametrize("data_shape", DATA_SHAPE) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", [{}]) def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) @@ -678,53 +638,64 @@ class TestRelativePositionBias(TestLayer): if "params_axes" in flax_variables: flax_variables, _ = flax.core.pop(flax_variables, "params_axes") if FP8Helper.is_fp8_enabled(): - flax_variables, _ = flax.core.pop(flax_variables, - FP8Helper.FP8_COLLECTION_NAME + "_axes") + flax_variables, _ = flax.core.pop( + flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes" + ) praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables) - praxis_loss= \ - TestLayer.loss(praxis_variables, *test_input, module=praxis_layer, mean_out=False) - flax_loss = \ - TestLayer.loss(flax_variables, *test_input, module=flax_layer, mean_out=False) + praxis_loss = 
TestLayer.loss( + praxis_variables, *test_input, module=praxis_layer, mean_out=False + ) + flax_loss = TestLayer.loss( + flax_variables, *test_input, module=flax_layer, mean_out=False + ) assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol) class DotProductAttnAttr: - ATTN_MASK_TYPE = 'attn_mask_type' - NUM_GQA_GROUPS = 'num_gqa_groups' - TRANSPOSE_BS = 'transpose_batch_sequence' - SCALE_FACTOR = 'scale_factor' - ATTRS = [{ - ATTN_MASK_TYPE: 'padding', - TRANSPOSE_BS: True, - SCALE_FACTOR: 0.125, - }, { - ATTN_MASK_TYPE: 'padding_causal', - TRANSPOSE_BS: True, - SCALE_FACTOR: 0.125, - }, { - ATTN_MASK_TYPE: 'causal', - TRANSPOSE_BS: True, - SCALE_FACTOR: 0.125, - }, { - ATTN_MASK_TYPE: 'padding', - TRANSPOSE_BS: False, - SCALE_FACTOR: 0.125, - }, { - ATTN_MASK_TYPE: 'padding_causal', - TRANSPOSE_BS: False, - SCALE_FACTOR: 2., - }, { - ATTN_MASK_TYPE: 'causal', - TRANSPOSE_BS: False, - SCALE_FACTOR: 1., - }, { - ATTN_MASK_TYPE: 'no_mask', - TRANSPOSE_BS: False, - SCALE_FACTOR: 1., - }] + ATTN_MASK_TYPE = "attn_mask_type" + NUM_GQA_GROUPS = "num_gqa_groups" + TRANSPOSE_BS = "transpose_batch_sequence" + SCALE_FACTOR = "scale_factor" + ATTRS = [ + { + ATTN_MASK_TYPE: "padding", + TRANSPOSE_BS: True, + SCALE_FACTOR: 0.125, + }, + { + ATTN_MASK_TYPE: "padding_causal", + TRANSPOSE_BS: True, + SCALE_FACTOR: 0.125, + }, + { + ATTN_MASK_TYPE: "causal", + TRANSPOSE_BS: True, + SCALE_FACTOR: 0.125, + }, + { + ATTN_MASK_TYPE: "padding", + TRANSPOSE_BS: False, + SCALE_FACTOR: 0.125, + }, + { + ATTN_MASK_TYPE: "padding_causal", + TRANSPOSE_BS: False, + SCALE_FACTOR: 2.0, + }, + { + ATTN_MASK_TYPE: "causal", + TRANSPOSE_BS: False, + SCALE_FACTOR: 1.0, + }, + { + ATTN_MASK_TYPE: "no_mask", + TRANSPOSE_BS: False, + SCALE_FACTOR: 1.0, + }, + ] class TestDotProductAttn(TestLayer): @@ -737,11 +708,12 @@ class TestDotProductAttn(TestLayer): shape = (shape[1], shape[0]) + shape[2:] mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) return [ - *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask + *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), + mask, ] def get_layer_name(self): - return 'dot_product_attn' + return "dot_product_attn" def generate_praxis_p_and_flax_cls(self, dtype, attrs): head_dim = 64 @@ -750,27 +722,31 @@ class TestDotProductAttn(TestLayer): attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE] transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS] - praxis_p = pax_fiddle.Config(DotProductAttention, - name='mha', - dtype=dtype, - head_dim=head_dim, - num_attention_heads=num_attention_heads, - num_gqa_groups=num_gqa_groups, - attn_mask_type=attn_mask_type, - transpose_batch_sequence=transpose_batch_sequence) - flax_cls = partial(flax_DotProductAttention, - dtype=dtype, - head_dim=head_dim, - num_attention_heads=num_attention_heads, - num_gqa_groups=num_gqa_groups, - attn_mask_type=attn_mask_type, - transpose_batch_sequence=transpose_batch_sequence) + praxis_p = pax_fiddle.Config( + DotProductAttention, + name="mha", + dtype=dtype, + head_dim=head_dim, + num_attention_heads=num_attention_heads, + num_gqa_groups=num_gqa_groups, + attn_mask_type=attn_mask_type, + transpose_batch_sequence=transpose_batch_sequence, + ) + flax_cls = partial( + flax_DotProductAttention, + dtype=dtype, + head_dim=head_dim, + num_attention_heads=num_attention_heads, + num_gqa_groups=num_gqa_groups, + attn_mask_type=attn_mask_type, + transpose_batch_sequence=transpose_batch_sequence, + ) return praxis_p, flax_cls - 
@pytest.mark.parametrize('data_shape', [(32, 128, 16, 64)]) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS) + @pytest.mark.parametrize("data_shape", [(32, 128, 16, 64)]) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", DotProductAttnAttr.ATTRS) def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): self.attrs = attrs praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) @@ -778,113 +754,125 @@ class TestDotProductAttn(TestLayer): class MultiHeadAttnAttr: - USE_BIAS = 'use_bias' - LN_TYPE = 'layernorm_type' - ATTN_MASK_TYPE = 'attn_mask_type' - ZERO_CEN = 'zero_centered_gamma' - NUM_ATTN_HEADS = 'num_attention_heads' - NUM_GQA_GROUPS = 'num_gqa_groups' - TRANSPOSE_BS = 'transpose_batch_sequence' - ENABLE_ROPE = 'enable_rotary_pos_emb' - ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method' - LORA_SCOPE = 'low_rank_adaptation_scope' - ATTRS = [{ - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - ATTN_MASK_TYPE: 'padding', - TRANSPOSE_BS: True, - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - ATTN_MASK_TYPE: 'padding', - TRANSPOSE_BS: False, - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - ATTN_MASK_TYPE: 'padding', - TRANSPOSE_BS: True, - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - ATTN_MASK_TYPE: 'causal', - TRANSPOSE_BS: False, - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - ATTN_MASK_TYPE: 'causal', - TRANSPOSE_BS: True, - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - ATTN_MASK_TYPE: 'causal', - TRANSPOSE_BS: False, - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - NUM_ATTN_HEADS: 8, - NUM_GQA_GROUPS: 4, - ATTN_MASK_TYPE: 'causal', - TRANSPOSE_BS: True, - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ENABLE_ROPE: True, - ROPE_GROUP_METHOD: 'consecutive', - NUM_ATTN_HEADS: 8, - NUM_GQA_GROUPS: 4, - ATTN_MASK_TYPE: 'causal', - TRANSPOSE_BS: False, - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ENABLE_ROPE: True, - ROPE_GROUP_METHOD: 'alternate', - NUM_ATTN_HEADS: 8, - NUM_GQA_GROUPS: 4, - ATTN_MASK_TYPE: 'causal', - TRANSPOSE_BS: True, - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - ATTN_MASK_TYPE: 'padding', - LORA_SCOPE: 'all', - TRANSPOSE_BS: False, - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - ATTN_MASK_TYPE: 'causal', - LORA_SCOPE: 'all', - TRANSPOSE_BS: True, - }] + USE_BIAS = "use_bias" + LN_TYPE = "layernorm_type" + ATTN_MASK_TYPE = "attn_mask_type" + ZERO_CEN = "zero_centered_gamma" + NUM_ATTN_HEADS = "num_attention_heads" + NUM_GQA_GROUPS = "num_gqa_groups" + TRANSPOSE_BS = "transpose_batch_sequence" + ENABLE_ROPE = "enable_rotary_pos_emb" + ROPE_GROUP_METHOD = "rotary_pos_emb_group_method" + LORA_SCOPE = "low_rank_adaptation_scope" + ATTRS = [ + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ENABLE_ROPE: 
False, + ROPE_GROUP_METHOD: "consecutive", + ATTN_MASK_TYPE: "padding", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: True, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + ATTN_MASK_TYPE: "padding", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + ATTN_MASK_TYPE: "padding", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + ATTN_MASK_TYPE: "causal", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: True, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + ATTN_MASK_TYPE: "causal", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + ATTN_MASK_TYPE: "causal", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + NUM_ATTN_HEADS: 8, + NUM_GQA_GROUPS: 4, + ATTN_MASK_TYPE: "causal", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ENABLE_ROPE: True, + ROPE_GROUP_METHOD: "consecutive", + NUM_ATTN_HEADS: 8, + NUM_GQA_GROUPS: 4, + ATTN_MASK_TYPE: "causal", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ENABLE_ROPE: True, + ROPE_GROUP_METHOD: "alternate", + NUM_ATTN_HEADS: 8, + NUM_GQA_GROUPS: 4, + ATTN_MASK_TYPE: "causal", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + ATTN_MASK_TYPE: "padding", + LORA_SCOPE: "all", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + ATTN_MASK_TYPE: "causal", + LORA_SCOPE: "all", + TRANSPOSE_BS: True, + }, + ] class TestMultiHeadAttn(TestLayer): @@ -899,13 +887,16 @@ class TestMultiHeadAttn(TestLayer): return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask] def get_layer_name(self): - return 'multi_head_attn' + return "multi_head_attn" def generate_praxis_p_and_flax_cls(self, dtype, attrs): head_dim = 64 num_attention_heads = 16 - num_gqa_groups = attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS] \ - if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs else None + num_gqa_groups = ( + attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS] + if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs + else None + ) layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE] zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN] kernel_init = WeightInit.Gaussian(1.0) @@ -916,35 +907,37 @@ class TestMultiHeadAttn(TestLayer): attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE] enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE] rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD] - low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none') + low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, "none") fuse_qkv_params = True transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS] scale_attn_logits = False scaled_query_init = True float32_logits = False - praxis_p = pax_fiddle.Config(MultiHeadAttention, - name='mha', - dtype=dtype, - head_dim=head_dim, - num_attention_heads=num_attention_heads, - num_gqa_groups=num_gqa_groups, - 
layernorm_type=layernorm_type, - zero_centered_gamma=zero_centered_gamma, - params_init=kernel_init, - use_bias=use_bias, - bias_init=bias_init, - return_layernorm_output=return_layernorm_output, - input_layernorm=input_layernorm, - attn_mask_type=attn_mask_type, - enable_rotary_pos_emb=enable_rotary_pos_emb, - rotary_pos_emb_group_method=rotary_pos_emb_group_method, - low_rank_adaptation_scope=low_rank_adaptation_scope, - fuse_qkv_params=fuse_qkv_params, - transpose_batch_sequence=transpose_batch_sequence, - scale_attn_logits=scale_attn_logits, - scaled_query_init=scaled_query_init, - float32_logits=float32_logits) + praxis_p = pax_fiddle.Config( + MultiHeadAttention, + name="mha", + dtype=dtype, + head_dim=head_dim, + num_attention_heads=num_attention_heads, + num_gqa_groups=num_gqa_groups, + layernorm_type=layernorm_type, + zero_centered_gamma=zero_centered_gamma, + params_init=kernel_init, + use_bias=use_bias, + bias_init=bias_init, + return_layernorm_output=return_layernorm_output, + input_layernorm=input_layernorm, + attn_mask_type=attn_mask_type, + enable_rotary_pos_emb=enable_rotary_pos_emb, + rotary_pos_emb_group_method=rotary_pos_emb_group_method, + low_rank_adaptation_scope=low_rank_adaptation_scope, + fuse_qkv_params=fuse_qkv_params, + transpose_batch_sequence=transpose_batch_sequence, + scale_attn_logits=scale_attn_logits, + scaled_query_init=scaled_query_init, + float32_logits=float32_logits, + ) flax_cls = partial( flax_MultiHeadAttention, dtype=dtype, @@ -966,30 +959,27 @@ class TestMultiHeadAttn(TestLayer): transpose_batch_sequence=transpose_batch_sequence, scale_attn_logits=scale_attn_logits, scaled_query_init=scaled_query_init, - float32_logits=float32_logits) + float32_logits=float32_logits, + ) return praxis_p, flax_cls - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS) + @pytest.mark.parametrize("data_shape", DATA_SHAPE) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS) def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): self.attrs = attrs praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS) - @pytest.mark.parametrize('fp8_format', FP8_FORMATS) - def test_forward_backward_fp8(self, - data_shape, - dtype, - attrs, - fp8_format, - rtol=1e-05, - atol=1e-08): + @pytest.mark.parametrize("data_shape", DATA_SHAPE) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS) + @pytest.mark.parametrize("fp8_format", FP8_FORMATS) + def test_forward_backward_fp8( + self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08 + ): self.attrs = attrs ds = DelayedScaling(fp8_format=fp8_format) with fp8_autocast(enabled=True, fp8_recipe=ds): @@ -998,252 +988,279 @@ class TestMultiHeadAttn(TestLayer): class TransformerLayerAttr: - USE_BIAS = 'use_bias' - LN_TYPE = 'layernorm_type' - ACTIVATION = 'activations' - LYR_TYPE = 'layer_type' - ZERO_CEN = 'zero_centered_gamma' - TRANSPOSE_BS = 'transpose_batch_sequence' - ENABLE_ROPE = 'enable_rotary_pos_emb' - ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method' - LORA_SCOPE = 
'low_rank_adaptation_scope' - ATTRS = [{ - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ACTIVATION: ('relu',), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: True - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ACTIVATION: ('relu',), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True, - ACTIVATION: ('relu',), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: True - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True, - ACTIVATION: ('relu',), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: ('relu',), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: True - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: ('relu',), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True, - ACTIVATION: ('relu',), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: True - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True, - ACTIVATION: ('relu',), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ACTIVATION: ('relu',), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: True - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ACTIVATION: ('relu',), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: ('relu',), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: True - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: ('relu',), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ACTIVATION: ('gelu', 'linear'), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: True - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ACTIVATION: ('gelu', 'linear'), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: ('gelu', 'linear'), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: True - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: ('gelu', 'linear'), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, 
- ACTIVATION: ('gelu',), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False, - LORA_SCOPE: 'all' - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ACTIVATION: ('gelu', 'linear'), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: True - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ACTIVATION: ('gelu', 'linear'), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: ('gelu', 'linear'), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: True - }, { - USE_BIAS: True, - LN_TYPE: 'rmsnorm', - ZERO_CEN: False, - ACTIVATION: ('gelu', 'linear'), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True, - ACTIVATION: ('gelu',), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: True, - ROPE_GROUP_METHOD: 'alternate', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True, - ACTIVATION: ('gelu',), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: True, - ROPE_GROUP_METHOD: 'alternate', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True, - ACTIVATION: ('gelu',), - LYR_TYPE: TransformerLayerType.ENCODER, - ENABLE_ROPE: True, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: True, - ACTIVATION: ('gelu',), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: True, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False - }, { - USE_BIAS: True, - LN_TYPE: 'layernorm', - ZERO_CEN: False, - ACTIVATION: ('gelu',), - LYR_TYPE: TransformerLayerType.DECODER, - ENABLE_ROPE: False, - ROPE_GROUP_METHOD: 'consecutive', - TRANSPOSE_BS: False, - LORA_SCOPE: 'all' - }] + USE_BIAS = "use_bias" + LN_TYPE = "layernorm_type" + ACTIVATION = "activations" + LYR_TYPE = "layer_type" + ZERO_CEN = "zero_centered_gamma" + TRANSPOSE_BS = "transpose_batch_sequence" + ENABLE_ROPE = "enable_rotary_pos_emb" + ROPE_GROUP_METHOD = "rotary_pos_emb_group_method" + LORA_SCOPE = "low_rank_adaptation_scope" + ATTRS = [ + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: True, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: True, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: 
True, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: True, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: True, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("gelu", "linear"), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("gelu", "linear"), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("gelu", "linear"), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("gelu", "linear"), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("gelu",), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + LORA_SCOPE: "all", + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("gelu", "linear"), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("gelu", "linear"), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("gelu", "linear"), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: True, + }, + { + USE_BIAS: True, + LN_TYPE: "rmsnorm", + ZERO_CEN: False, + ACTIVATION: ("gelu", "linear"), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + 
TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: True, + ACTIVATION: ("gelu",), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: True, + ROPE_GROUP_METHOD: "alternate", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: True, + ACTIVATION: ("gelu",), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: True, + ROPE_GROUP_METHOD: "alternate", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: True, + ACTIVATION: ("gelu",), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: True, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: True, + ACTIVATION: ("gelu",), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: True, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("gelu",), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + LORA_SCOPE: "all", + }, + ] class TestTransformer(TestLayer): @@ -1256,11 +1273,13 @@ class TestTransformer(TestLayer): shape = (shape[1], shape[0]) + shape[2:] mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) return [ - *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask + *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), + mask, + mask, ] def get_layer_name(self): - return 'transformerlayer' + return "transformerlayer" def generate_praxis_p_and_flax_cls(self, dtype, attrs): hidden_size = 512 @@ -1277,97 +1296,102 @@ class TestTransformer(TestLayer): layer_type = attrs[TransformerLayerAttr.LYR_TYPE] enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE] rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD] - low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none') + low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, "none") enable_relative_embedding = True - relative_embedding = pax_fiddle.Config(RelativePositionBiases, - dtype=dtype, - num_attention_heads=num_attention_heads) + relative_embedding = pax_fiddle.Config( + RelativePositionBiases, dtype=dtype, num_attention_heads=num_attention_heads + ) drop_path = 0.0 transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS] rel_embedding_init = RelativePositionBiases.generate_embedding_init( - relative_embedding.embedding_init, relative_embedding.num_attention_heads, - relative_embedding.num_buckets) + relative_embedding.embedding_init, + relative_embedding.num_attention_heads, + relative_embedding.num_buckets, + ) relative_embedding_flax_module = flax_RelativePositionBiases( num_buckets=relative_embedding.num_buckets, max_distance=relative_embedding.max_distance, num_attention_heads=relative_embedding.num_attention_heads, embedding_init=TransformerEngineBaseLayer.generate_params_init( - "rel_embedding", rel_embedding_init), + "rel_embedding", rel_embedding_init + ), embedding_axes=relative_embedding.embedding_axes, - dtype=relative_embedding.dtype) - - praxis_p = pax_fiddle.Config(TransformerLayer, - name='transformer_layer', - params_init=kernel_init, - dtype=dtype, - hidden_size=hidden_size, - mlp_hidden_size=mlp_hidden_size, - num_attention_heads=num_attention_heads, - layernorm_type=layernorm_type, - hidden_dropout=hidden_dropout, - attention_dropout=attention_dropout, - intermediate_dropout=intermediate_dropout, 
- mlp_activations=mlp_activations, - use_bias=use_bias, - bias_init=bias_init, - layer_type=layer_type, - enable_relative_embedding=enable_relative_embedding, - enable_rotary_pos_emb=enable_rotary_pos_emb, - rotary_pos_emb_group_method=rotary_pos_emb_group_method, - low_rank_adaptation_scope=low_rank_adaptation_scope, - relative_embedding=relative_embedding, - drop_path=drop_path, - transpose_batch_sequence=transpose_batch_sequence) - flax_cls = partial(flax_TransformerLayer, - dtype=dtype, - hidden_size=hidden_size, - mlp_hidden_size=mlp_hidden_size, - num_attention_heads=num_attention_heads, - layernorm_type=layernorm_type, - hidden_dropout=hidden_dropout, - attention_dropout=attention_dropout, - intermediate_dropout=intermediate_dropout, - mlp_activations=mlp_activations, - mha_kernel_init=TransformerEngineBaseLayer.generate_params_init( - "mha_kernel", kernel_init), - mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init( - "mlp_kernel", kernel_init), - use_bias=use_bias, - bias_init=TransformerEngineBaseLayer.generate_params_init( - "bias", bias_init), - layer_type=layer_type, - enable_rotary_pos_emb=enable_rotary_pos_emb, - rotary_pos_emb_group_method=rotary_pos_emb_group_method, - enable_relative_embedding=enable_relative_embedding, - relative_embedding=relative_embedding_flax_module, - low_rank_adaptation_scope=low_rank_adaptation_scope, - drop_path=drop_path, - transpose_batch_sequence=transpose_batch_sequence) + dtype=relative_embedding.dtype, + ) + + praxis_p = pax_fiddle.Config( + TransformerLayer, + name="transformer_layer", + params_init=kernel_init, + dtype=dtype, + hidden_size=hidden_size, + mlp_hidden_size=mlp_hidden_size, + num_attention_heads=num_attention_heads, + layernorm_type=layernorm_type, + hidden_dropout=hidden_dropout, + attention_dropout=attention_dropout, + intermediate_dropout=intermediate_dropout, + mlp_activations=mlp_activations, + use_bias=use_bias, + bias_init=bias_init, + layer_type=layer_type, + enable_relative_embedding=enable_relative_embedding, + enable_rotary_pos_emb=enable_rotary_pos_emb, + rotary_pos_emb_group_method=rotary_pos_emb_group_method, + low_rank_adaptation_scope=low_rank_adaptation_scope, + relative_embedding=relative_embedding, + drop_path=drop_path, + transpose_batch_sequence=transpose_batch_sequence, + ) + flax_cls = partial( + flax_TransformerLayer, + dtype=dtype, + hidden_size=hidden_size, + mlp_hidden_size=mlp_hidden_size, + num_attention_heads=num_attention_heads, + layernorm_type=layernorm_type, + hidden_dropout=hidden_dropout, + attention_dropout=attention_dropout, + intermediate_dropout=intermediate_dropout, + mlp_activations=mlp_activations, + mha_kernel_init=TransformerEngineBaseLayer.generate_params_init( + "mha_kernel", kernel_init + ), + mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init( + "mlp_kernel", kernel_init + ), + use_bias=use_bias, + bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init), + layer_type=layer_type, + enable_rotary_pos_emb=enable_rotary_pos_emb, + rotary_pos_emb_group_method=rotary_pos_emb_group_method, + enable_relative_embedding=enable_relative_embedding, + relative_embedding=relative_embedding_flax_module, + low_rank_adaptation_scope=low_rank_adaptation_scope, + drop_path=drop_path, + transpose_batch_sequence=transpose_batch_sequence, + ) return praxis_p, flax_cls - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS) + 
@pytest.mark.parametrize("data_shape", DATA_SHAPE) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS) def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): self.attrs = attrs praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS) - @pytest.mark.parametrize('fp8_format', FP8_FORMATS) - def test_forward_backward_fp8(self, - data_shape, - dtype, - attrs, - fp8_format, - rtol=1e-05, - atol=1e-08): + @pytest.mark.parametrize("data_shape", DATA_SHAPE) + @pytest.mark.parametrize("dtype", DTYPE) + @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS) + @pytest.mark.parametrize("fp8_format", FP8_FORMATS) + def test_forward_backward_fp8( + self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08 + ): self.attrs = attrs ds = DelayedScaling(fp8_format=fp8_format) with fp8_autocast(enabled=True, fp8_recipe=ds): diff --git a/tests/jax/test_sanity_import.py b/tests/jax/test_sanity_import.py index cc07e2d63cf699b3f9d0c3a661cf2fa2f634d0a1..f47c2eb411cfffd1559f835281d1e8192a5c7ae6 100644 --- a/tests/jax/test_sanity_import.py +++ b/tests/jax/test_sanity_import.py @@ -3,4 +3,5 @@ # See LICENSE for license information. import transformer_engine.jax + print("OK") diff --git a/tests/jax/test_sharding.py b/tests/jax/test_sharding.py index 484728ed71bca2333f212c271a96656e373983ec..4581cdc39ee5503fe89a50ee7e7cf72d7187f03a 100644 --- a/tests/jax/test_sharding.py +++ b/tests/jax/test_sharding.py @@ -8,25 +8,25 @@ from transformer_engine.jax.flax import extend_logical_axis_rules from transformer_engine.jax.sharding import global_shard_guard, MeshResource LOGICAL_RULES = [ - [(('a1', None), ('a2', 'ma2')), False], - [(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True], - [(('a1', None), ('a2', 'ma2'), ('a3', 'ma31'), ('a3', 'ma32')), False], - [(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True], - [(('a1', None), ('a2', 'ma2'), ('a2', 'ma1'), ('batch', 'model'), ('batch', 'data')), True], + [(("a1", None), ("a2", "ma2")), False], + [(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True], + [(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False], + [(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True], + [(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True], ] MeshS = [ MeshResource(), - MeshResource('data', None), - MeshResource(None, 'model'), - MeshResource('data', 'model') + MeshResource("data", None), + MeshResource(None, "model"), + MeshResource("data", "model"), ] class TestShardingSideAPI: - @pytest.mark.parametrize('base_rules,need_assert', LOGICAL_RULES) - @pytest.mark.parametrize('sr', MeshS) + @pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES) + @pytest.mark.parametrize("sr", MeshS) def test_extend_logical_axis_rules(self, base_rules, need_assert, sr): with global_shard_guard(sr): try: diff --git a/tests/jax/test_softmax.py b/tests/jax/test_softmax.py index 3df268cbb1c380efa1c1d502162cd2ffa68c5609..0cff5955fa4f3f71a5948379b51a2ae381d5d844 100644 --- a/tests/jax/test_softmax.py +++ b/tests/jax/test_softmax.py @@ -43,6 +43,7 @@ class SoftmaxRunner: """ Softmax runner """ + batch_size: int 
max_seqlen_q: int max_seqlen_kv: int @@ -57,14 +58,22 @@ class SoftmaxRunner: Jax softmax as the reference """ if mask is not None: - logits += lax.select(mask > 0, - jnp.full(mask.shape, -1e10).astype(logits.dtype), - jnp.full(mask.shape, 0.).astype(logits.dtype)) + logits += lax.select( + mask > 0, + jnp.full(mask.shape, -1e10).astype(logits.dtype), + jnp.full(mask.shape, 0.0).astype(logits.dtype), + ) return nn.softmax(logits * scale_factor) def _is_support(self): - return is_softmax_kernel_available(self.softmax_type, self.batch_size, self.num_heads, - self.max_seqlen_q, self.max_seqlen_kv, self.dtype) + return is_softmax_kernel_available( + self.softmax_type, + self.batch_size, + self.num_heads, + self.max_seqlen_q, + self.max_seqlen_kv, + self.dtype, + ) def _setup_inputs(self): key = jax.random.PRNGKey(0) @@ -73,7 +82,7 @@ class SoftmaxRunner: logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv) mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) - self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.) + self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0) match self.softmax_type: case SoftmaxType.SCALED: @@ -81,7 +90,7 @@ class SoftmaxRunner: case SoftmaxType.SCALED_MASKED: self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8) case SoftmaxType.SCALED_UPPER_TRIANG_MASKED: - self.mask = (1. - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8) + self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8) case _: raise ValueError(f"Unknown {self.softmax_type=}") @@ -108,18 +117,24 @@ class SoftmaxRunner: args = [self.logits, self.mask] kwargs = { - 'scale_factor': self.scale_factor, - 'softmax_type': self.softmax_type, + "scale_factor": self.scale_factor, + "softmax_type": self.softmax_type, } # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation jitted_primitive = jit( - value_and_grad(lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), - (0,))) + value_and_grad( + lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), (0,) + ) + ) jitted_reference = jit( value_and_grad( - lambda logits, *args: grad_func(__class__.reference_softmax, self.logits, *args, ** - kwargs), (0,))) + lambda logits, *args: grad_func( + __class__.reference_softmax, self.logits, *args, **kwargs + ), + (0,), + ) + ) primitive_out, (primitive_grad_logits,) = jitted_primitive(*args) reference_out, (reference_grad_logits,) = jitted_reference(*args) @@ -128,21 +143,30 @@ class SoftmaxRunner: assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype) -@pytest.mark.parametrize('b, s_q, s_kv, h', [ - pytest.param(8, 16, 16, 16, id='8-16-16-16'), - pytest.param(8, 512, 512, 16, id='8-512-512-16'), - pytest.param(2, 8, 16384, 8, id='2-8-16384-8') -]) -@pytest.mark.parametrize('scale_factor', [0.125]) -@pytest.mark.parametrize('softmax_type', [ - pytest.param(SoftmaxType.SCALED, id='SCALED'), - pytest.param(SoftmaxType.SCALED_MASKED, id='SCALED_MASKED'), - pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id='SCALED_UPPER_TRIANG_MASKED') -]) -@pytest.mark.parametrize('dtype', [ - pytest.param(jnp.bfloat16, id="BF16"), - pytest.param(jnp.float16, id="FP16"), -]) +@pytest.mark.parametrize( + "b, s_q, s_kv, h", + [ + pytest.param(8, 16, 16, 16, id="8-16-16-16"), + pytest.param(8, 512, 512, 16, id="8-512-512-16"), + pytest.param(2, 8, 16384, 8, id="2-8-16384-8"), + ], +) 
+@pytest.mark.parametrize("scale_factor", [0.125]) +@pytest.mark.parametrize( + "softmax_type", + [ + pytest.param(SoftmaxType.SCALED, id="SCALED"), + pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"), + pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"), + ], +) +@pytest.mark.parametrize( + "dtype", + [ + pytest.param(jnp.bfloat16, id="BF16"), + pytest.param(jnp.float16, id="FP16"), + ], +) class TestSoftmax: """ Test transformer_engine.jax.softmax.softmax diff --git a/tests/jax/utils.py b/tests/jax/utils.py index ee933005d05a927d3c5f53386c9e37d4cb5c476f..798c2a82bae7136952bb0233a1f716ec20676c63 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -24,8 +24,9 @@ PRNGKey = Any Shape = Tuple[int, ...] DType = jnp.dtype Array = Any -PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, - lax.Precision]] +PrecisionLike = Union[ + None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] +] Initializer = Callable[[PRNGKey, Shape, DType], Array] @@ -56,7 +57,7 @@ def _canonicalize_tuple(x): def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable: """Convert a string to an activation function.""" - if fn_or_string == 'linear': + if fn_or_string == "linear": return lambda x: x if isinstance(fn_or_string, str): return getattr(nn, fn_or_string) @@ -68,17 +69,18 @@ def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Calla def combine_biases(*masks: Optional[Array]): """Combine attention biases. - Args: - *masks: set of attention bias arguments to combine, some can be None. + Args: + *masks: set of attention bias arguments to combine, some can be None. - Returns: - Combined mask, reduced by summation, returns None if no masks given. - """ + Returns: + Combined mask, reduced by summation, returns None if no masks given. + """ masks = [m for m in masks if m is not None] if not masks: return None - assert all(map(lambda x: x.ndim == masks[0].ndim, - masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') + assert all( + map(lambda x: x.ndim == masks[0].ndim, masks) + ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}" mask, *other_masks = masks for other_mask in other_masks: mask = mask + other_mask @@ -88,7 +90,7 @@ def combine_biases(*masks: Optional[Array]): class DotProductAttention(nn.Module): transpose_batch_sequence: bool = True scale_attn_logits: bool = True - dropout_rate: float = 0. + dropout_rate: float = 0.0 dtype: DType = jnp.float32 float32_logits: bool = False """Computes dot-product attention given query, key, and value. @@ -105,12 +107,14 @@ class DotProductAttention(nn.Module): """ @nn.compact - def __call__(self, - query: Array, - key: Array, - value: Array, - bias: Optional[Array] = None, - deterministic: bool = False): + def __call__( + self, + query: Array, + key: Array, + value: Array, + bias: Optional[Array] = None, + deterministic: bool = False, + ): """ Args: query: queries for calculating attention with shape of `[batch, q_length, @@ -127,14 +131,15 @@ class DotProductAttention(nn.Module): Returns: Output of shape `[batch, length, num_heads, v_depth_per_head]`. """ - assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' + assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank." 
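# --- Illustrative note, not part of the patch: both the reference softmax in
# tests/jax/test_softmax.py above and MultiHeadAttention further below follow the same
# convention of turning a {0, 1} mask into an additive bias (-1e10 at masked positions)
# before calling softmax. A minimal standalone sketch of that convention; the helper
# name and shapes here are ours, not from the test suite:
import jax
import jax.numpy as jnp

def masked_scaled_softmax(logits, mask, scale_factor):
    # mask > 0 marks positions to suppress; add a large negative bias there.
    bias = jnp.where(mask > 0, -1e10, 0.0).astype(logits.dtype)
    return jax.nn.softmax((logits + bias) * scale_factor, axis=-1)

logits = jax.random.uniform(jax.random.PRNGKey(0), (2, 4, 8, 8), jnp.float32, -1.0)
mask = (1.0 - jnp.tril(jnp.ones((2, 4, 8, 8)))).astype(jnp.uint8)  # causal-style mask
probs = masked_scaled_softmax(logits, mask, 0.125)
assert jnp.allclose(probs.sum(axis=-1), 1.0, atol=1e-5)  # masked rows still normalize to 1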
batch_dim = 1 if self.transpose_batch_sequence else 0 - assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], ( - 'q, k, v batch dims must match.') + assert ( + query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim] + ), "q, k, v batch dims must match." sequence_dim = 0 if self.transpose_batch_sequence else 1 - assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.' - assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.' - assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.' + assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match." + assert key.shape[-2] == value.shape[-2], "k, v num_heads must match." + assert query.shape[-1] == key.shape[-1], "q, k head_dim must match." if self.scale_attn_logits: head_dim = query.shape[-1] @@ -153,9 +158,9 @@ class DotProductAttention(nn.Module): grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) if self.transpose_batch_sequence: - attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key) + attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key) else: - attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key) + attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key) # reshape back to normal DPA shape for bias/softmax/dropout b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape @@ -170,37 +175,37 @@ class DotProductAttention(nn.Module): attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype) # Apply attention dropout. - if not deterministic and self.dropout_rate > 0.: + if not deterministic and self.dropout_rate > 0.0: keep_prob = 1.0 - self.dropout_rate # T5 broadcasts along the "length" dim, but unclear which one that # corresponds to in positional dimensions here, assuming query dim. dropout_shape = list(attn_weights.shape) - dropout_rng = self.make_rng('dropout') + dropout_rng = self.make_rng("dropout") keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) - multiplier = (keep.astype(attn_weights.dtype) / - jnp.asarray(keep_prob, dtype=self.dtype)) + multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype) attn_weights = attn_weights * multiplier attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) # Take the linear combination of `value`. if self.transpose_batch_sequence: - return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape) + return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape) - return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape) + return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape) class DenseGeneral(nn.Module): """A linear transformation with flexible axes and FP8 support. - Attributes: - features: tuple with numbers of output features. - axis: tuple with axes to apply the transformation on. - dtype: the dtype of the computation (default: float32). - kernel_init: initializer function for the weight matrix. - use_bias: whether to add a bias to the output (default: False). - bias_init: initializer function for the bias vector. + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + use_bias: whether to add a bias to the output (default: False). 
+ bias_init: initializer function for the bias vector. """ + features: Union[Iterable[int], int] axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 @@ -212,7 +217,7 @@ class DenseGeneral(nn.Module): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') + self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") super().__post_init__() @nn.compact @@ -233,21 +238,17 @@ class DenseGeneral(nn.Module): kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features)) - kernel = nn_partitioning.param_with_axes('kernel', - self.kernel_init, - kernel_param_shape, - jnp.float32, - axes=self.kernel_axes) + kernel = nn_partitioning.param_with_axes( + "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes + ) kernel = jnp.asarray(kernel, self.dtype) kernel = jnp.reshape(kernel, kernel_shape) if self.use_bias: - bias = nn_partitioning.param_with_axes('bias', - self.bias_init, - self.features, - jnp.float32, - axes=self.bias_axes) + bias = nn_partitioning.param_with_axes( + "bias", self.bias_init, self.features, jnp.float32, axes=self.bias_axes + ) bias = bias.astype(self.dtype) else: bias = None @@ -264,18 +265,19 @@ class DenseGeneral(nn.Module): class MlpBlock(nn.Module): """Transformer MLP / feed-forward block. - Attributes: - intermediate_dim: Shared dimension of hidden layers. - activations: Type of activations for each layer. Each element is either - 'linear', a string function name in flax.linen, or a function. - kernel_init: Kernel function, passed to the dense layers. - deterministic: Whether the dropout layers should be deterministic. - intermediate_dropout_rate: Dropout rate used after the intermediate layers. - dtype: Type for the dense layer. - """ + Attributes: + intermediate_dim: Shared dimension of hidden layers. + activations: Type of activations for each layer. Each element is either + 'linear', a string function name in flax.linen, or a function. + kernel_init: Kernel function, passed to the dense layers. + deterministic: Whether the dropout layers should be deterministic. + intermediate_dropout_rate: Dropout rate used after the intermediate layers. + dtype: Type for the dense layer. 
+ """ + transpose_batch_sequence: bool intermediate_dim: int = 2048 - activations: Sequence[Union[str, Callable]] = ('relu',) + activations: Sequence[Union[str, Callable]] = ("relu",) kernel_init: Initializer = None intermediate_dropout_rate: float = 0.1 intermediate_dropout_dims: Sequence[int] = () @@ -285,7 +287,7 @@ class MlpBlock(nn.Module): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') + self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") super().__post_init__() @nn.compact @@ -296,49 +298,57 @@ class MlpBlock(nn.Module): activations = [] if self.fuse_wi: - dense_name = 'wi' + dense_name = "wi" num_activations = len(self.activations) - x = DenseGeneral(self.intermediate_dim * num_activations, - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'mlp'), - use_bias=self.use_bias, - bias_axes=('mlp'), - name=dense_name)(inputs) + x = DenseGeneral( + self.intermediate_dim * num_activations, + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=("embed", "mlp"), + use_bias=self.use_bias, + bias_axes="mlp", + name=dense_name, + )(inputs) x = jnp.split(x, num_activations, axis=-1) for idx, act_fn in enumerate(self.activations): x_i = _convert_to_activation_function(act_fn)(x[idx]) activations.append(x_i) else: for idx, act_fn in enumerate(self.activations): - dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' - x = DenseGeneral(self.intermediate_dim, - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'mlp'), - use_bias=self.use_bias, - bias_axes=('mlp'), - name=dense_name)(inputs) + dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}" + x = DenseGeneral( + self.intermediate_dim, + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=("embed", "mlp"), + use_bias=self.use_bias, + bias_axes="mlp", + name=dense_name, + )(inputs) x = _convert_to_activation_function(act_fn)(x) activations.append(x) # Take elementwise product of above intermediate activations. x = functools.reduce(operator.mul, activations) # Apply dropout and final dense output projection. - x = nn.Dropout(rate=self.intermediate_dropout_rate, - broadcast_dims=self.intermediate_dropout_dims)( - x, deterministic=deterministic) # Broadcast along length. + x = nn.Dropout( + rate=self.intermediate_dropout_rate, broadcast_dims=self.intermediate_dropout_dims + )( + x, deterministic=deterministic + ) # Broadcast along length. 
if self.transpose_batch_sequence: - x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp')) + x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp")) else: - x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'mlp')) - output = DenseGeneral(inputs.shape[-1], - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=('mlp', 'embed'), - use_bias=self.use_bias, - bias_axes=('embed'), - name='wo')(x) + x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp")) + output = DenseGeneral( + inputs.shape[-1], + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=("mlp", "embed"), + use_bias=self.use_bias, + bias_axes="embed", + name="wo", + )(x) return output @@ -351,7 +361,7 @@ def apply_rotary_pos_emb_alternate( embedding_dim = inputs.shape[-1] half_embedding_dim = embedding_dim // 2 fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim - timescale = min_timescale * (max_timescale / min_timescale)**fraction + timescale = min_timescale * (max_timescale / min_timescale) ** fraction timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1))) position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim))) sinusoid_inp = position / timescale @@ -386,7 +396,7 @@ def apply_rotary_pos_emb_consecutive( inputs_shifted_left, ) fraction = jnp.repeat(fraction, 2) - timescale = min_timescale * (max_timescale / min_timescale)**fraction + timescale = min_timescale * (max_timescale / min_timescale) ** fraction position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim))) @@ -415,89 +425,96 @@ class MultiHeadAttention(nn.Module): kernel_init: initializer for the kernel of the Dense layers. float32_logits: bool, if True then compute logits in float32 to avoid numerical issues with bfloat16. - """ + """ num_heads: int = 8 num_gqa_groups: int | None = None head_dim: int = 64 transpose_batch_sequence: bool = True dtype: DType = jnp.float32 - dropout_rate: float = 0. + dropout_rate: float = 0.0 kernel_init: Initializer = None - float32_logits: bool = False # computes logits in float32 for stability. + float32_logits: bool = False # computes logits in float32 for stability. scale_attn_logits: bool = False scaled_query_init: bool = True enable_rotary_pos_emb: bool = False - rotary_pos_emb_group_method: str = 'consecutive' + rotary_pos_emb_group_method: str = "consecutive" fuse_qkv: bool = True use_bias: bool = False def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal') + self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal") if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads super().__post_init__() @nn.compact - def __call__(self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - *, - decode: bool = False, - deterministic: bool = False) -> Array: + def __call__( + self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + *, + decode: bool = False, + deterministic: bool = False, + ) -> Array: """Applies multi-head dot product attention on the input data. - Projects the inputs into multi-headed query, key, and value vectors, - applies dot-product attention and project the results to an output vector. + Projects the inputs into multi-headed query, key, and value vectors, + applies dot-product attention and project the results to an output vector. 
- There are two modes: decoding and non-decoding (e.g., training). The mode is - determined by `decode` argument. For decoding, this method is called twice, - first to initialize the cache and then for an actual decoding process. The - two calls are differentiated by the presence of 'cached_key' in the variable - dict. In the cache initialization stage, the cache variables are initialized - as zeros and will be filled in the subsequent decoding process. + There are two modes: decoding and non-decoding (e.g., training). The mode is + determined by `decode` argument. For decoding, this method is called twice, + first to initialize the cache and then for an actual decoding process. The + two calls are differentiated by the presence of 'cached_key' in the variable + dict. In the cache initialization stage, the cache variables are initialized + as zeros and will be filled in the subsequent decoding process. - In the cache initialization call, `inputs_q` has a shape [batch, length, - q_features] and `inputs_kv`: [batch, length, kv_features]. During the - incremental decoding stage, query, key and value all have the shape [batch, - 1, qkv_features] corresponding to a single step. + In the cache initialization call, `inputs_q` has a shape [batch, length, + q_features] and `inputs_kv`: [batch, length, kv_features]. During the + incremental decoding stage, query, key and value all have the shape [batch, + 1, qkv_features] corresponding to a single step. - Args: - inputs_q: input queries of shape `[batch, q_length, q_features]`. - inputs_kv: key/values of shape `[batch, kv_length, kv_features]`. - mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. - bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. - decode: Whether to prepare and use an autoregressive cache. - deterministic: Disables dropout if set to True. + Args: + inputs_q: input queries of shape `[batch, q_length, q_features]`. + inputs_kv: key/values of shape `[batch, kv_length, kv_features]`. + mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. + bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. + decode: Whether to prepare and use an autoregressive cache. + deterministic: Disables dropout if set to True. - Returns: - output of shape `[batch, length, q_features]`. - """ - q_projection = functools.partial(DenseGeneral, - axis=-1, - features=self.num_heads * self.head_dim, - kernel_axes=('embed', 'joined_kv'), - use_bias=self.use_bias, - bias_axes=('joined_kv'), - dtype=self.dtype) - - kv_projection = functools.partial(DenseGeneral, - axis=-1, - features=self.num_gqa_groups * self.head_dim, - kernel_axes=('embed', 'joined_kv'), - use_bias=self.use_bias, - bias_axes=('joined_kv'), - dtype=self.dtype) + Returns: + output of shape `[batch, length, q_features]`. + """ + q_projection = functools.partial( + DenseGeneral, + axis=-1, + features=self.num_heads * self.head_dim, + kernel_axes=("embed", "joined_kv"), + use_bias=self.use_bias, + bias_axes="joined_kv", + dtype=self.dtype, + ) + + kv_projection = functools.partial( + DenseGeneral, + axis=-1, + features=self.num_gqa_groups * self.head_dim, + kernel_axes=("embed", "joined_kv"), + use_bias=self.use_bias, + bias_axes="joined_kv", + dtype=self.dtype, + ) # NOTE: T5 does not explicitly rescale the attention logits by # 1/sqrt(depth_kq)! 
This is folded into the initializers of the # linear transformations, which is equivalent under Adafactor depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) - query_init = lambda *args: self.kernel_init(*args) / (depth_scaling - if self.scaled_query_init else 1.0) + query_init = lambda *args: self.kernel_init(*args) / ( + depth_scaling if self.scaled_query_init else 1.0 + ) # Project inputs_q to multi-headed q/k/v # dimensions are then [batch, length, num_heads, head_dim] @@ -515,39 +532,45 @@ class MultiHeadAttention(nn.Module): return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype) - is_self_attn = (inputs_q is inputs_kv) - is_gqa = (self.num_heads != self.num_gqa_groups) - is_qkvpack = (is_self_attn and not is_gqa) + is_self_attn = inputs_q is inputs_kv + is_gqa = self.num_heads != self.num_gqa_groups + is_qkvpack = is_self_attn and not is_gqa if self.fuse_qkv: if is_qkvpack: - qkv_proj = DenseGeneral(axis=-1, - features=self.num_heads * self.head_dim * 3, - kernel_axes=('embed', 'joined_kv'), - kernel_init=qkv_init, - use_bias=self.use_bias, - bias_axes=('joined_kv'), - name='qkv', - dtype=self.dtype)(inputs_kv) + qkv_proj = DenseGeneral( + axis=-1, + features=self.num_heads * self.head_dim * 3, + kernel_axes=("embed", "joined_kv"), + kernel_init=qkv_init, + use_bias=self.use_bias, + bias_axes="joined_kv", + name="qkv", + dtype=self.dtype, + )(inputs_kv) query, key, value = jnp.split( - qkv_proj, [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2], - axis=-1) + qkv_proj, + [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2], + axis=-1, + ) else: - query = q_projection(kernel_init=query_init, name='query')(inputs_q) - - kv_proj = DenseGeneral(axis=-1, - features=self.num_gqa_groups * self.head_dim * 2, - kernel_axes=('embed', 'joined_kv'), - kernel_init=self.kernel_init, - use_bias=self.use_bias, - bias_axes=('joined_kv'), - name='kv', - dtype=self.dtype)(inputs_kv) + query = q_projection(kernel_init=query_init, name="query")(inputs_q) + + kv_proj = DenseGeneral( + axis=-1, + features=self.num_gqa_groups * self.head_dim * 2, + kernel_axes=("embed", "joined_kv"), + kernel_init=self.kernel_init, + use_bias=self.use_bias, + bias_axes="joined_kv", + name="kv", + dtype=self.dtype, + )(inputs_kv) key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1) else: - query = q_projection(kernel_init=query_init, name='query')(inputs_q) - key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv) - value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv) + query = q_projection(kernel_init=query_init, name="query")(inputs_q) + key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv) + value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv) if self.enable_rotary_pos_emb: batch_dim = 1 if self.transpose_batch_sequence else 0 @@ -556,7 +579,7 @@ class MultiHeadAttention(nn.Module): q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim) k_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim) - if self.rotary_pos_emb_group_method == 'alternate': + if self.rotary_pos_emb_group_method == "alternate": apply_rotary_pos_emb = apply_rotary_pos_emb_alternate else: apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive @@ -571,33 +594,40 @@ class MultiHeadAttention(nn.Module): value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) if self.transpose_batch_sequence: - query = 
nn_partitioning.with_sharding_constraint(query, - ('length', 'batch', 'heads', 'kv')) - key = nn_partitioning.with_sharding_constraint(key, ('length', 'batch', 'heads', 'kv')) - value = nn_partitioning.with_sharding_constraint(value, - ('length', 'batch', 'heads', 'kv')) + query = nn_partitioning.with_sharding_constraint( + query, ("length", "batch", "heads", "kv") + ) + key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv")) + value = nn_partitioning.with_sharding_constraint( + value, ("length", "batch", "heads", "kv") + ) else: - query = nn_partitioning.with_sharding_constraint(query, - ('batch', 'length', 'heads', 'kv')) - key = nn_partitioning.with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv')) - value = nn_partitioning.with_sharding_constraint(value, - ('batch', 'length', 'heads', 'kv')) + query = nn_partitioning.with_sharding_constraint( + query, ("batch", "length", "heads", "kv") + ) + key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv")) + value = nn_partitioning.with_sharding_constraint( + value, ("batch", "length", "heads", "kv") + ) if decode: # Detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable('cache', 'cached_key') + is_initialized = self.has_variable("cache", "cached_key") # The key and value have dimension [batch, length, num_heads, head_dim], # but we cache them as [batch, num_heads, head_dim, length] as a TPU # fusion optimization. This also enables the "scatter via one-hot # broadcast" trick, which means we do a one-hot broadcast instead of a # scatter/gather operations, resulting in a 3-4x speedup in practice. swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) - cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape), - key.dtype) - cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape), - value.dtype) - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.int32)) + cached_key = self.variable( + "cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype + ) + cached_value = self.variable( + "cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype + ) + cache_index = self.variable( + "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32) + ) if is_initialized: batch, num_heads, head_dim, length = cached_key.value.shape # During fast autoregressive decoding, we feed one position at a time, @@ -606,8 +636,9 @@ class MultiHeadAttention(nn.Module): expected_shape = (batch, 1, num_heads, head_dim) if expected_shape != query.shape: raise ValueError( - 'Autoregressive cache shape error, ' - f"expected query shape {expected_shape} instead got {query.shape}.") + "Autoregressive cache shape error, " + f"expected query shape {expected_shape} instead got {query.shape}." + ) # Create a OHE of the current index. NOTE: the index is increased below. cur_index = cache_index.value @@ -638,11 +669,13 @@ class MultiHeadAttention(nn.Module): jnp.logical_not(mask), jnp.broadcast_to( jnp.arange(length) <= cur_index, - # (1, 1, length) represent (head dim, query length, key length) - # query length is 1 because during decoding we deal with one - # index. - # The same mask is applied to all batch elements and heads. - (batch, 1, 1, length))) + # (1, 1, length) represent (head dim, query length, key length) + # query length is 1 because during decoding we deal with one + # index. 
+ # The same mask is applied to all batch elements and heads. + (batch, 1, 1, length), + ), + ) # Grab the correct relative attention bias during decoding. This is # only required during single step decoding. @@ -650,15 +683,18 @@ class MultiHeadAttention(nn.Module): # The bias is a full attention matrix, but during decoding we only # have to take a slice of it. # This is equivalent to bias[..., cur_index:cur_index+1, :]. - bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), - jnp.reshape(cur_index, (-1)), 1, -2) + bias = dynamic_vector_slice_in_dim( + jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2 + ) # Convert the boolean attention mask to an attention bias. if mask is not None: # attention mask in the form of attention bias - attention_bias = lax.select(mask > 0, - jnp.full(mask.shape, 0.).astype(self.dtype), - jnp.full(mask.shape, -1e10).astype(self.dtype)) + attention_bias = lax.select( + mask > 0, + jnp.full(mask.shape, 0.0).astype(self.dtype), + jnp.full(mask.shape, -1e10).astype(self.dtype), + ) else: attention_bias = None @@ -667,41 +703,41 @@ class MultiHeadAttention(nn.Module): attention_bias = combine_biases(attention_bias, bias) # Apply attention. - x = DotProductAttention(transpose_batch_sequence=self.transpose_batch_sequence, - scale_attn_logits=self.scale_attn_logits, - dropout_rate=self.dropout_rate, - dtype=self.dtype, - float32_logits=self.float32_logits)(query, - key, - value, - bias=attention_bias, - deterministic=deterministic) + x = DotProductAttention( + transpose_batch_sequence=self.transpose_batch_sequence, + scale_attn_logits=self.scale_attn_logits, + dropout_rate=self.dropout_rate, + dtype=self.dtype, + float32_logits=self.float32_logits, + )(query, key, value, bias=attention_bias, deterministic=deterministic) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) if self.transpose_batch_sequence: - x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'joined_kv')) + x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv")) else: - x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'joined_kv')) + x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv")) # Back to the original inputs dimensions. out = DenseGeneral( - features=inputs_q.shape[-1], # output dim is set to the input dim. + features=inputs_q.shape[-1], # output dim is set to the input dim. 
axis=-1, kernel_init=self.kernel_init, - kernel_axes=('joined_kv', 'embed'), + kernel_axes=("joined_kv", "embed"), use_bias=self.use_bias, - bias_axes=('embed'), + bias_axes="embed", dtype=self.dtype, - name='out')(x) + name="out", + )(x) return out class LayerNorm(nn.Module): """T5 Layer normalization operating on the last axis of the input data.""" + epsilon: float = 1e-6 dtype: Any = jnp.float32 - layernorm_type: str = 'layernorm' + layernorm_type: str = "layernorm" zero_centered_gamma: bool = False scale_init: Initializer = None bias_init: Initializer = nn.initializers.zeros @@ -721,29 +757,27 @@ class LayerNorm(nn.Module): x = jnp.asarray(x, jnp.float32) features = x.shape[-1] - scale = nn_partitioning.param_with_axes('scale', - self.scale_init, (features,), - jnp.float32, - axes=('embed',)) + scale = nn_partitioning.param_with_axes( + "scale", self.scale_init, (features,), jnp.float32, axes=("embed",) + ) scale = jnp.asarray(scale, self.dtype) - if self.layernorm_type == 'layernorm': + if self.layernorm_type == "layernorm": mean = jnp.mean(x, axis=-1, keepdims=True) var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True) y = (x - mean) * lax.rsqrt(var + self.epsilon) - bias = nn_partitioning.param_with_axes('ln_bias', - self.bias_init, (features,), - jnp.float32, - axes=('embed',)) + bias = nn_partitioning.param_with_axes( + "ln_bias", self.bias_init, (features,), jnp.float32, axes=("embed",) + ) bias = jnp.asarray(bias, self.dtype) if not self.zero_centered_gamma: z = y * scale + bias else: - z = y * (scale + 1.) + bias + z = y * (scale + 1.0) + bias else: - assert self.layernorm_type == 'rmsnorm' + assert self.layernorm_type == "rmsnorm" assert not self.zero_centered_gamma mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) y = x * lax.rsqrt(mean2 + self.epsilon) @@ -755,16 +789,17 @@ class LayerNorm(nn.Module): class RelativePositionBiases(nn.Module): """Adds T5-style relative positional embeddings to the attention logits. - Attributes: - num_buckets: Number of buckets to bucket distances between key and query - positions into. - max_distance: Maximum distance before everything is lumped into the last - distance bucket. - num_heads: Number of heads in the attention layer. Each head will get a - different relative position weighting. - dtype: Type of arrays through this module. - embedding_init: initializer for relative embedding table. - """ + Attributes: + num_buckets: Number of buckets to bucket distances between key and query + positions into. + max_distance: Maximum distance before everything is lumped into the last + distance bucket. + num_heads: Number of heads in the attention layer. Each head will get a + different relative position weighting. + dtype: Type of arrays through this module. + embedding_init: initializer for relative embedding table. + """ + num_buckets: int max_distance: int num_heads: int @@ -772,33 +807,32 @@ class RelativePositionBiases(nn.Module): embedding_init: Callable[..., Array] = nn.linear.default_embed_init @staticmethod - def _relative_position_bucket(relative_position, - bidirectional=True, - num_buckets=32, - max_distance=128): + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): """Translate relative position to a bucket number for relative attention. - The relative position is defined as memory_position - query_position, i.e. - the distance in tokens from the attending position to the attended-to - position. 
If bidirectional=False, then positive relative positions are - invalid. - We use smaller buckets for small absolute relative_position and larger - buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative - positions <=-max_distance map to the same bucket. This should allow for - more graceful generalization to longer sequences than the model has been - trained on. + The relative position is defined as memory_position - query_position, i.e. + the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are + invalid. + We use smaller buckets for small absolute relative_position and larger + buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative + positions <=-max_distance map to the same bucket. This should allow for + more graceful generalization to longer sequences than the model has been + trained on. - Args: - relative_position: an int32 array - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer + Args: + relative_position: an int32 array + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer - Returns: - a Tensor with the same shape as relative_position, containing int32 - values in the range [0, num_buckets) - """ + Returns: + a Tensor with the same shape as relative_position, containing int32 + values in the range [0, num_buckets) + """ ret = 0 n = -relative_position if bidirectional: @@ -811,8 +845,10 @@ class RelativePositionBiases(nn.Module): max_exact = num_buckets // 2 is_small = n < max_exact val_if_large = max_exact + ( - np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) / - np.log(max_distance / max_exact) * (num_buckets - max_exact)).astype(np.int32) + np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) + / np.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).astype(np.int32) val_if_large = np.minimum(val_if_large, num_buckets - 1) ret += np.where(is_small, n, val_if_large) return ret @@ -821,27 +857,31 @@ class RelativePositionBiases(nn.Module): def __call__(self, qlen, klen, bidirectional=True): """Produce relative position embedding attention biases. - Args: - qlen: attention query length. - klen: attention key length. - bidirectional: whether to allow positive memory-query relative position - embeddings. + Args: + qlen: attention query length. + klen: attention key length. + bidirectional: whether to allow positive memory-query relative position + embeddings. 
- Returns: - output: `(1, len, q_len, k_len)` attention bias - """ + Returns: + output: `(1, len, q_len, k_len)` attention bias + """ context_position = np.arange(qlen, dtype=jnp.int32)[:, None] memory_position = np.arange(klen, dtype=jnp.int32)[None, :] - relative_position = memory_position - context_position # shape (qlen, klen) - rp_bucket = self._relative_position_bucket(relative_position, - bidirectional=bidirectional, - num_buckets=self.num_buckets, - max_distance=self.max_distance) + relative_position = memory_position - context_position # shape (qlen, klen) + rp_bucket = self._relative_position_bucket( + relative_position, + bidirectional=bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) relative_attention_bias = nn_partitioning.param_with_axes( - 'rel_embedding', - self.embedding_init, (self.num_heads, self.num_buckets), + "rel_embedding", + self.embedding_init, + (self.num_heads, self.num_buckets), jnp.float32, - axes=('heads', 'relpos_buckets')) + axes=("heads", "relpos_buckets"), + ) relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) # Instead of using a slow gather, we create a leading-dimension one-hot @@ -855,9 +895,8 @@ class RelativePositionBiases(nn.Module): values = lax.dot_general( relative_attention_bias, rp_bucket_one_hot, - ( - ((1,), (0,)), # rhs, lhs contracting dims - ((), ()))) # no batched dims + (((1,), (0,)), ((), ())), # rhs, lhs contracting dims + ) # no batched dims # Add a singleton batch dimension. # --> shape (1, num_heads, qlen, klen) return values[jnp.newaxis, ...] @@ -865,6 +904,7 @@ class RelativePositionBiases(nn.Module): class EncoderLayer(nn.Module): """Transformer encoder layer.""" + enable_relative_embedding: bool = True relative_embedding: nn.Module = None num_attention_heads: int = 8 @@ -880,17 +920,17 @@ class EncoderLayer(nn.Module): scale_attn_logits: bool = False scaled_query_init: bool = True mlp_dim: int = 2048 - mlp_activations: Sequence[str] = ('relu',) + mlp_activations: Sequence[str] = ("relu",) use_bias: bool = False dtype: Any = jnp.float32 apply_residual_connection_post_layernorm: bool = False - layernorm_type: str = 'layernorm' + layernorm_type: str = "layernorm" layernorm_epsilon: float = 1e-6 zero_centered_gamma: bool = False output_layernorm: bool = False drop_path: float = 0.0 enable_rotary_pos_emb: bool = False - rotary_pos_emb_group_method: str = 'consecutive' + rotary_pos_emb_group_method: str = "consecutive" fuse_qkv_params: bool = True fuse_mlp_wi: bool = True self_attn_bias_type: Any = None @@ -903,20 +943,21 @@ class EncoderLayer(nn.Module): @nn.compact def __call__(self, inputs, encoder_mask=None, deterministic=False): - del self.self_attn_mask_type # dummy, just align to TE's impl + del self.self_attn_mask_type # dummy, just align to TE's impl # Relative position embedding as attention biases. 
sequence_dim = 0 if self.transpose_batch_sequence else 1 batch_dim = 1 - sequence_dim if self.enable_relative_embedding: if self.relative_embedding is None: - rel_emb = RelativePositionBiases(num_buckets=32, - max_distance=128, - num_heads=self.num_attention_heads, - dtype=self.dtype, - embedding_init=nn.initializers.variance_scaling( - 1.0, 'fan_avg', 'uniform'), - name='relpos_bias') + rel_emb = RelativePositionBiases( + num_buckets=32, + max_distance=128, + num_heads=self.num_attention_heads, + dtype=self.dtype, + embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"), + name="relpos_bias", + ) else: rel_emb = self.relative_embedding encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True) @@ -928,11 +969,13 @@ class EncoderLayer(nn.Module): if not self.output_layernorm: # Attention block. - x = LayerNorm(layernorm_type=self.layernorm_type, - epsilon=self.layernorm_epsilon, - zero_centered_gamma=self.zero_centered_gamma, - dtype=self.dtype, - name="pre_attention_layer_norm")(inputs) + x = LayerNorm( + layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, + zero_centered_gamma=self.zero_centered_gamma, + dtype=self.dtype, + name="pre_attention_layer_norm", + )(inputs) if self.apply_residual_connection_post_layernorm: residual = x @@ -940,39 +983,41 @@ class EncoderLayer(nn.Module): x = inputs # [batch, length, emb_dim] -> [batch, length, emb_dim] - x = MultiHeadAttention(num_heads=self.num_attention_heads, - num_gqa_groups=self.num_gqa_groups, - dtype=self.dtype, - head_dim=self.head_dim, - transpose_batch_sequence=self.transpose_batch_sequence, - dropout_rate=self.attention_dropout, - float32_logits=self.float32_attention_logits, - scale_attn_logits=self.scale_attn_logits, - scaled_query_init=self.scaled_query_init, - fuse_qkv=self.fuse_qkv_params, - enable_rotary_pos_emb=self.enable_rotary_pos_emb, - rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, - use_bias=self.use_bias, - name='attention')(x, - x, - encoder_mask, - encoder_bias, - deterministic=deterministic) - x = nn.Dropout(rate=self.hidden_dropout, - broadcast_dims=self.hidden_dropout_dims)(x, deterministic=deterministic) + x = MultiHeadAttention( + num_heads=self.num_attention_heads, + num_gqa_groups=self.num_gqa_groups, + dtype=self.dtype, + head_dim=self.head_dim, + transpose_batch_sequence=self.transpose_batch_sequence, + dropout_rate=self.attention_dropout, + float32_logits=self.float32_attention_logits, + scale_attn_logits=self.scale_attn_logits, + scaled_query_init=self.scaled_query_init, + fuse_qkv=self.fuse_qkv_params, + enable_rotary_pos_emb=self.enable_rotary_pos_emb, + rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, + use_bias=self.use_bias, + name="attention", + )(x, x, encoder_mask, encoder_bias, deterministic=deterministic) + x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)( + x, deterministic=deterministic + ) if self.drop_path > 0.0: drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim) - x = nn.Dropout(rate=self.drop_path, - broadcast_dims=drop_path_shape)(x, deterministic=deterministic) + x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)( + x, deterministic=deterministic + ) x = x + residual # MLP block. 
residual = x - y = LayerNorm(layernorm_type=self.layernorm_type, - epsilon=self.layernorm_epsilon, - zero_centered_gamma=self.zero_centered_gamma, - dtype=self.dtype, - name='pre_mlp_layer_norm')(x) + y = LayerNorm( + layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, + zero_centered_gamma=self.zero_centered_gamma, + dtype=self.dtype, + name="pre_mlp_layer_norm", + )(x) if self.apply_residual_connection_post_layernorm: residual = y @@ -987,27 +1032,32 @@ class EncoderLayer(nn.Module): use_bias=self.use_bias, dtype=self.dtype, fuse_wi=self.fuse_mlp_wi, - name='mlp', + name="mlp", )(y, deterministic=deterministic) - y = nn.Dropout(rate=self.hidden_dropout, - broadcast_dims=self.hidden_dropout_dims)(y, deterministic=deterministic) + y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)( + y, deterministic=deterministic + ) if self.drop_path > 0.0: drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim) - y = nn.Dropout(rate=self.drop_path, - broadcast_dims=drop_path_shape)(y, deterministic=deterministic) + y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)( + y, deterministic=deterministic + ) y = y + residual if self.output_layernorm: - y = LayerNorm(layernorm_type=self.layernorm_type, - epsilon=self.layernorm_epsilon, - zero_centered_gamma=self.zero_centered_gamma, - dtype=self.dtype, - name="output_layernorm")(y) + y = LayerNorm( + layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, + zero_centered_gamma=self.zero_centered_gamma, + dtype=self.dtype, + name="output_layernorm", + )(y) return y class DecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + enable_relative_embedding: bool = True relative_embedding: nn.Module = None num_attention_heads: int = 8 @@ -1023,17 +1073,17 @@ class DecoderLayer(nn.Module): scale_attn_logits: bool = False scaled_query_init: bool = True mlp_dim: int = 2048 - mlp_activations: Sequence[str] = ('relu',) + mlp_activations: Sequence[str] = ("relu",) use_bias: bool = False dtype: Any = jnp.float32 apply_residual_connection_post_layernorm: bool = False output_layernorm: bool = False - layernorm_type: str = 'layernorm' + layernorm_type: str = "layernorm" layernorm_epsilon: float = 1e-6 zero_centered_gamma: bool = False drop_path: float = 0.0 enable_rotary_pos_emb: bool = False - rotary_pos_emb_group_method: str = 'consecutive' + rotary_pos_emb_group_method: str = "consecutive" fuse_qkv_params: bool = True fuse_mlp_wi: bool = True self_attn_bias_type: Any = None @@ -1045,15 +1095,17 @@ class DecoderLayer(nn.Module): super().__post_init__() @nn.compact - def __call__(self, - inputs, - encoded, - decoder_mask=None, - encoder_decoder_mask=None, - deterministic=False, - decode=False, - max_decode_length=None): - del self.self_attn_mask_type # dummy, just align to TE's impl + def __call__( + self, + inputs, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + deterministic=False, + decode=False, + max_decode_length=None, + ): + del self.self_attn_mask_type # dummy, just align to TE's impl # Relative position embedding as attention biases. 
sequence_dim = 0 if self.transpose_batch_sequence else 1 batch_dim = 1 - sequence_dim @@ -1061,13 +1113,14 @@ class DecoderLayer(nn.Module): if self.enable_relative_embedding: l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim] if self.relative_embedding is None: - rel_emb = RelativePositionBiases(num_buckets=32, - max_distance=128, - num_heads=self.num_attention_heads, - dtype=self.dtype, - embedding_init=nn.initializers.variance_scaling( - 1.0, 'fan_avg', 'uniform'), - name='relpos_bias') + rel_emb = RelativePositionBiases( + num_buckets=32, + max_distance=128, + num_heads=self.num_attention_heads, + dtype=self.dtype, + embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"), + name="relpos_bias", + ) else: rel_emb = self.relative_embedding decoder_bias = rel_emb(l, l, False) @@ -1079,11 +1132,13 @@ class DecoderLayer(nn.Module): if not self.output_layernorm: # Attention block. - x = LayerNorm(layernorm_type=self.layernorm_type, - epsilon=self.layernorm_epsilon, - zero_centered_gamma=self.zero_centered_gamma, - dtype=self.dtype, - name="pre_self_attention_layer_norm")(inputs) + x = LayerNorm( + layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, + zero_centered_gamma=self.zero_centered_gamma, + dtype=self.dtype, + name="pre_self_attention_layer_norm", + )(inputs) if self.apply_residual_connection_post_layernorm: residual = x @@ -1091,71 +1146,74 @@ class DecoderLayer(nn.Module): x = inputs # Self-attention block - x = MultiHeadAttention(num_heads=self.num_attention_heads, - num_gqa_groups=self.num_gqa_groups, - dtype=self.dtype, - head_dim=self.head_dim, - transpose_batch_sequence=self.transpose_batch_sequence, - dropout_rate=self.attention_dropout, - float32_logits=self.float32_attention_logits, - scale_attn_logits=self.scale_attn_logits, - scaled_query_init=self.scaled_query_init, - enable_rotary_pos_emb=self.enable_rotary_pos_emb, - rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, - fuse_qkv=self.fuse_qkv_params, - use_bias=self.use_bias, - name='self_attention')(x, - x, - decoder_mask, - decoder_bias, - deterministic=deterministic, - decode=decode) - x = nn.Dropout(rate=self.hidden_dropout, - broadcast_dims=self.hidden_dropout_dims)(x, deterministic=deterministic) + x = MultiHeadAttention( + num_heads=self.num_attention_heads, + num_gqa_groups=self.num_gqa_groups, + dtype=self.dtype, + head_dim=self.head_dim, + transpose_batch_sequence=self.transpose_batch_sequence, + dropout_rate=self.attention_dropout, + float32_logits=self.float32_attention_logits, + scale_attn_logits=self.scale_attn_logits, + scaled_query_init=self.scaled_query_init, + enable_rotary_pos_emb=self.enable_rotary_pos_emb, + rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, + fuse_qkv=self.fuse_qkv_params, + use_bias=self.use_bias, + name="self_attention", + )(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode) + x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)( + x, deterministic=deterministic + ) if self.drop_path > 0.0: drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim) - x = nn.Dropout(rate=self.drop_path, - broadcast_dims=drop_path_shape)(x, deterministic=deterministic) + x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)( + x, deterministic=deterministic + ) x = x + residual # Encoder-Decoder block. 
residual = x - y = LayerNorm(layernorm_type=self.layernorm_type, - epsilon=self.layernorm_epsilon, - zero_centered_gamma=self.zero_centered_gamma, - dtype=self.dtype, - name='pre_cross_attention_layer_norm')(x) + y = LayerNorm( + layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, + zero_centered_gamma=self.zero_centered_gamma, + dtype=self.dtype, + name="pre_cross_attention_layer_norm", + )(x) if self.apply_residual_connection_post_layernorm: residual = y - y = MultiHeadAttention(num_heads=self.num_attention_heads, - num_gqa_groups=self.num_gqa_groups, - dtype=self.dtype, - head_dim=self.head_dim, - transpose_batch_sequence=self.transpose_batch_sequence, - dropout_rate=self.attention_dropout, - float32_logits=self.float32_attention_logits, - scale_attn_logits=self.scale_attn_logits, - scaled_query_init=self.scaled_query_init, - enable_rotary_pos_emb=self.enable_rotary_pos_emb, - rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, - fuse_qkv=self.fuse_qkv_params, - use_bias=self.use_bias, - name='encoder_decoder_attention')(y, - encoded, - encoder_decoder_mask, - deterministic=deterministic) - y = nn.Dropout(rate=self.hidden_dropout, - broadcast_dims=self.hidden_dropout_dims)(y, deterministic=deterministic) + y = MultiHeadAttention( + num_heads=self.num_attention_heads, + num_gqa_groups=self.num_gqa_groups, + dtype=self.dtype, + head_dim=self.head_dim, + transpose_batch_sequence=self.transpose_batch_sequence, + dropout_rate=self.attention_dropout, + float32_logits=self.float32_attention_logits, + scale_attn_logits=self.scale_attn_logits, + scaled_query_init=self.scaled_query_init, + enable_rotary_pos_emb=self.enable_rotary_pos_emb, + rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, + fuse_qkv=self.fuse_qkv_params, + use_bias=self.use_bias, + name="encoder_decoder_attention", + )(y, encoded, encoder_decoder_mask, deterministic=deterministic) + y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)( + y, deterministic=deterministic + ) y = y + residual # MLP block. 
residual = y - z = LayerNorm(layernorm_type=self.layernorm_type, - epsilon=self.layernorm_epsilon, - zero_centered_gamma=self.zero_centered_gamma, - dtype=self.dtype, - name='pre_mlp_layer_norm')(y) + z = LayerNorm( + layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, + zero_centered_gamma=self.zero_centered_gamma, + dtype=self.dtype, + name="pre_mlp_layer_norm", + )(y) if self.apply_residual_connection_post_layernorm: residual = z z = MlpBlock( @@ -1167,22 +1225,26 @@ class DecoderLayer(nn.Module): use_bias=self.use_bias, dtype=self.dtype, fuse_wi=self.fuse_mlp_wi, - name='mlp', + name="mlp", )(z, deterministic=deterministic) - z = nn.Dropout(rate=self.hidden_dropout, - broadcast_dims=self.hidden_dropout_dims)(z, deterministic=deterministic) + z = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)( + z, deterministic=deterministic + ) if self.drop_path > 0.0: drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim) - z = nn.Dropout(rate=self.drop_path, - broadcast_dims=drop_path_shape)(z, deterministic=deterministic) + z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)( + z, deterministic=deterministic + ) z = z + residual if self.output_layernorm: - z = LayerNorm(layernorm_type=self.layernorm_type, - epsilon=self.layernorm_epsilon, - zero_centered_gamma=self.zero_centered_gamma, - dtype=self.dtype, - name="output_layernorm")(z) + z = LayerNorm( + layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, + zero_centered_gamma=self.zero_centered_gamma, + dtype=self.dtype, + name="output_layernorm", + )(z) return z @@ -1261,15 +1323,18 @@ def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08): flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected) flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual) - for (expected_path, expected_value), (actual_path, - actual_value) in zip(flatten_expected, flatten_actual): + for (expected_path, expected_value), (actual_path, actual_value) in zip( + flatten_expected, flatten_actual + ): assert expected_path == actual_path key_str = jax.tree_util.keystr(expected_path) - assert_allclose(expected_value, - actual_value, - rtol=rtol, - atol=atol, - err_msg=f'Value of expected{key_str} and actual{key_str} is not close') + assert_allclose( + expected_value, + actual_value, + rtol=rtol, + atol=atol, + err_msg=f"Value of expected{key_str} and actual{key_str} is not close", + ) def dtype_tols( @@ -1323,7 +1388,7 @@ def dtype_tols( ) -def sync_params_values(dst, src, transformations, sep='/'): +def sync_params_values(dst, src, transformations, sep="/"): """ This function will reconstuct a tree with dst's tree_def/shape and src's value. transformations is a map that records the key mappings between dst and src. 
diff --git a/tests/paddle/dist_launcher.py b/tests/paddle/dist_launcher.py index eb8e8cb6627812a7e7ab2638f3f8f9b5f2de5bda..8c417b19303b6deb43bd7d6a4bc2a7ba80af962a 100644 --- a/tests/paddle/dist_launcher.py +++ b/tests/paddle/dist_launcher.py @@ -21,15 +21,15 @@ from paddle.distributed.utils.launch_utils import ( watch_local_trainers, ) -__all__ = ['TestDistributed'] +__all__ = ["TestDistributed"] def get_cluster_from_args(selected_gpus): """Get node information from selected GPUs""" - cluster_node_ips = '127.0.0.1' - node_ip = '127.0.0.1' + cluster_node_ips = "127.0.0.1" + node_ip = "127.0.0.1" - node_ips = [x.strip() for x in cluster_node_ips.split(',')] + node_ips = [x.strip() for x in cluster_node_ips.split(",")] node_ips.index(node_ip) @@ -47,7 +47,7 @@ def get_cluster_from_args(selected_gpus): def get_gpus(selected_gpus): """Get selected GPU string""" - selected_gpus = [x.strip() for x in selected_gpus.split(',')] + selected_gpus = [x.strip() for x in selected_gpus.split(",")] return selected_gpus @@ -86,7 +86,7 @@ def start_local_trainers( print(f"trainer proc env:{current_env}") - if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': + if os.getenv("WITH_COVERAGE", "OFF") == "ON": cmd = "python -m coverage run --branch -p " + training_script else: cmd = "python -u " + training_script @@ -95,7 +95,9 @@ def start_local_trainers( fn = None - proc = subprocess.Popen(cmd.split(" ") + training_script_args, env=current_env) # pylint: disable=consider-using-with + proc = subprocess.Popen( + cmd.split(" ") + training_script_args, env=current_env + ) # pylint: disable=consider-using-with tp = TrainerProc() tp.proc = proc @@ -117,10 +119,10 @@ class TestDistributed(unittest.TestCase): allocator_strategy="auto_growth", ): """Run target file in subprocesses""" - if (not core.is_compiled_with_cuda() or core.get_cuda_device_count() == 0): + if not core.is_compiled_with_cuda() or core.get_cuda_device_count() == 0: return - selected_gpus = get_gpus('0,1') + selected_gpus = get_gpus("0,1") cluster = None pod = None diff --git a/tests/paddle/parallel_tests/amax_reduction.py b/tests/paddle/parallel_tests/amax_reduction.py index 093eb50263e76b6e40671125ad9fc4bb5d9ca387..c4605f121e89002e1c69a1872883c062f0949b18 100644 --- a/tests/paddle/parallel_tests/amax_reduction.py +++ b/tests/paddle/parallel_tests/amax_reduction.py @@ -27,7 +27,7 @@ class TestAmaxReduction(unittest.TestCase): def setUp(self): self.data_parallel_size = 2 self.init_dist_env() - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" paddle.set_default_dtype(self.global_dtype) def init_dist_env(self): @@ -83,5 +83,5 @@ class TestAmaxReduction(unittest.TestCase): assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale_inv) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/paddle/parallel_tests/attention_tp.py b/tests/paddle/parallel_tests/attention_tp.py index 31ce75f4c9429a3e5a7b87155ad193272b766234..e145f20b3901c24a1ff88b1191eb40c2ad120c9b 100644 --- a/tests/paddle/parallel_tests/attention_tp.py +++ b/tests/paddle/parallel_tests/attention_tp.py @@ -44,8 +44,8 @@ class TestAttentionTp(unittest.TestCase): self.num_heads = 16 self.q_seqlen = 128 self.kv_seqlen = 128 - self.mask_type = 'padding' - self.global_dtype = 'bfloat16' + self.mask_type = "padding" + self.global_dtype = "bfloat16" self.rtol = 5e-3 self.atol = 5e-3 self.eps = 1e-3 @@ -56,7 +56,7 @@ class TestAttentionTp(unittest.TestCase): inp, mask = inp_list if sequence_parallel: split_size = inp.shape[0] // self.world_size - 
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :] + input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] else: input_parallel = inp with te.fp8_autocast(enabled=fp8_enabled): @@ -80,18 +80,20 @@ class TestAttentionTp(unittest.TestCase): self.num_heads, ) common_kwargs = { - 'layernorm_epsilon': self.eps, - 'attention_dropout': 0.0, - 'attn_mask_type': self.mask_type, - 'attention_type': 'self', + "layernorm_epsilon": self.eps, + "attention_dropout": 0.0, + "attn_mask_type": self.mask_type, + "attention_type": "self", "tp_group": self.tp_group, "input_layernorm": True, } - layer_tp = te.MultiHeadAttention(*common_args, - **common_kwargs, - set_parallel_mode=True, - sequence_parallel=self.sequence_parallel) + layer_tp = te.MultiHeadAttention( + *common_args, + **common_kwargs, + set_parallel_mode=True, + sequence_parallel=self.sequence_parallel, + ) layer_single = te.MultiHeadAttention(*common_args, **common_kwargs, set_parallel_mode=False) def _get_total_weight(local_weight, tp_group, axis, interleave=False): @@ -102,12 +104,15 @@ class TestAttentionTp(unittest.TestCase): # Due to the interleaved qkv layout, need to concat on num_head # dimension for column parallel linear in MultiHeadAttention layer assert axis == 0 - assert [3 * self.hidden_size // self.world_size, - self.hidden_size] == partial_weight.shape + assert [ + 3 * self.hidden_size // self.world_size, + self.hidden_size, + ] == partial_weight.shape local_num_head = self.num_heads // self.world_size for idx, _ in enumerate(total_weight): total_weight[idx] = total_weight[idx].reshape( - [3, local_num_head, -1, self.hidden_size]) + [3, local_num_head, -1, self.hidden_size] + ) total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size]) else: total_weight = paddle.concat(total_weight, axis=axis) @@ -123,42 +128,47 @@ class TestAttentionTp(unittest.TestCase): weight_dst = _get_weight(layer_dst, weight_names) if partition_mode is None: total_weight = weight_src - elif partition_mode == 'column': - total_weight = _get_total_weight(weight_src, - tp_group=self.tp_group, - axis=0, - interleave=interleave) - elif partition_mode == 'row': + elif partition_mode == "column": + total_weight = _get_total_weight( + weight_src, tp_group=self.tp_group, axis=0, interleave=interleave + ) + elif partition_mode == "row": total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1) else: raise ValueError(f"Partition Mode {partition_mode} is not supported.") - assert weight_dst.shape == total_weight.shape, \ - f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." + assert ( + weight_dst.shape == total_weight.shape + ), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." 
weight_dst.copy_(total_weight, True) - copy_weight(layer_tp, layer_single, None, ['layernorm_qkv', 'ln_weight']) - copy_weight(layer_tp, layer_single, 'column', ['layernorm_qkv', 'weight'], interleave=True) - copy_weight(layer_tp, layer_single, 'row', ['proj', 'weight']) + copy_weight(layer_tp, layer_single, None, ["layernorm_qkv", "ln_weight"]) + copy_weight(layer_tp, layer_single, "column", ["layernorm_qkv", "weight"], interleave=True) + copy_weight(layer_tp, layer_single, "row", ["proj", "weight"]) if self.sequence_parallel: register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1) optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters()) - optimizer_single = paddle.optimizer.SGD(learning_rate=0.01, - parameters=layer_single.parameters()) + optimizer_single = paddle.optimizer.SGD( + learning_rate=0.01, parameters=layer_single.parameters() + ) layer_tp = fleet.distributed_model(layer_tp) optimizer_tp = fleet.distributed_optimizer(optimizer_tp) for _ in range(5): - inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size], - self.global_dtype) - mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), - dtype='bool') - loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8, - self.sequence_parallel) - loss_single, out_single = self._train_one_step(layer_single, [inp, mask], - optimizer_single, self.fp8) + inp = paddle.uniform( + [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype + ) + mask = paddle.zeros( + shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool" + ) + loss_tp, out_tp = self._train_one_step( + layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel + ) + loss_single, out_single = self._train_one_step( + layer_single, [inp, mask], optimizer_single, self.fp8 + ) assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol) assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol) @@ -173,8 +183,8 @@ class TestAttentionTpFp8(TestAttentionTp): self.num_heads = 16 self.q_seqlen = 128 self.kv_seqlen = 128 - self.mask_type = 'padding' - self.global_dtype = 'bfloat16' + self.mask_type = "padding" + self.global_dtype = "bfloat16" self.rtol = 5e-2 self.atol = 5e-2 self.eps = 1e-3 @@ -192,8 +202,8 @@ class TestAttentionSp(TestAttentionTp): self.num_heads = 16 self.q_seqlen = 128 self.kv_seqlen = 128 - self.mask_type = 'padding' - self.global_dtype = 'bfloat16' + self.mask_type = "padding" + self.global_dtype = "bfloat16" self.rtol = 5e-3 self.atol = 5e-3 self.eps = 1e-3 @@ -211,8 +221,8 @@ class TestAttentionSpFp8(TestAttentionTp): self.num_heads = 16 self.q_seqlen = 128 self.kv_seqlen = 128 - self.mask_type = 'padding' - self.global_dtype = 'bfloat16' + self.mask_type = "padding" + self.global_dtype = "bfloat16" self.rtol = 5e-2 self.atol = 1e-1 self.eps = 1e-3 @@ -220,5 +230,5 @@ class TestAttentionSpFp8(TestAttentionTp): self.sequence_parallel = True -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/paddle/parallel_tests/group_sharding.py b/tests/paddle/parallel_tests/group_sharding.py index 3489b040a50efa7bba7ffe3a505a1a8d9a2fb94e..11060be38e0cf7c6442a80d6e35037dce9e33a7b 100644 --- a/tests/paddle/parallel_tests/group_sharding.py +++ b/tests/paddle/parallel_tests/group_sharding.py @@ -8,7 +8,8 @@ import unittest import paddle from paddle.distributed import fleet from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import ( - 
DygraphShardingOptimizer,) + DygraphShardingOptimizer, +) from utils import assert_allclose, set_random_seed import transformer_engine.paddle as te @@ -25,7 +26,7 @@ class TestGroupSharding(unittest.TestCase): def set_attr(self): """Set test configs""" self.sharding_degree = 2 - self.global_dtype = 'float32' + self.global_dtype = "float32" self.rtol = 1e-5 self.atol = 1e-5 self.batch_size = 16 @@ -57,11 +58,12 @@ class TestGroupSharding(unittest.TestCase): optimizer = paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()) group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group() - class ShardingLevel: # pylint: disable=too-few-public-methods, + class ShardingLevel: # pylint: disable=too-few-public-methods, """Paddle sharding options""" - kStage1 = 'os' - kStage2 = 'os_g' - kStage3 = 'p_g_os' + + kStage1 = "os" + kStage2 = "os_g" + kStage3 = "p_g_os" level = ShardingLevel.kStage3 if stage == 3 else ShardingLevel.kStage2 model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel( @@ -104,8 +106,9 @@ class TestGroupSharding(unittest.TestCase): loss_pd = train_one_step(model_pd, inp, optimizer_pd) assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - assert len(optimizer_te.state_dict()) == 4, \ - "Expect each rank to hold 4 optimizer state entries." + assert ( + len(optimizer_te.state_dict()) == 4 + ), "Expect each rank to hold 4 optimizer state entries." def test_group_sharding_stage2(self): """Tests group sharding training""" @@ -141,8 +144,9 @@ class TestGroupSharding(unittest.TestCase): loss_pd = train_one_step(model_pd, inp, optimizer_pd) assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - assert len(optimizer_te.state_dict()) == 4, \ - "Expect each rank to hold 4 optimizer state entries." + assert ( + len(optimizer_te.state_dict()) == 4 + ), "Expect each rank to hold 4 optimizer state entries." def test_group_sharding_stage3(self): """Tests group sharding training""" @@ -174,11 +178,11 @@ class TestGroupSharding(unittest.TestCase): assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) for name, value in optimizer_te.state_dict().items(): - if name.endswith('w_0_moment1_0'): - assert value.numel() == \ - self.in_channels * self.out_channels // self.sharding_degree, \ - "Expect optimizer state to be sharded across trainers." + if name.endswith("w_0_moment1_0"): + assert ( + value.numel() == self.in_channels * self.out_channels // self.sharding_degree + ), "Expect optimizer state to be sharded across trainers." 
-if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/paddle/parallel_tests/layernorm_linear_tp.py b/tests/paddle/parallel_tests/layernorm_linear_tp.py index 93142a571d5adb39e6b99286df04043c2efa0332..02295a71da46cacfb0a8bf94c826d1ae17262c39 100644 --- a/tests/paddle/parallel_tests/layernorm_linear_tp.py +++ b/tests/paddle/parallel_tests/layernorm_linear_tp.py @@ -42,22 +42,22 @@ class TestLayerNormLinearTp(unittest.TestCase): self.batch_size = 16 self.in_features = 32 self.out_features = 64 - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" self.rtol = 1e-3 self.atol = 1e-3 self.eps = 1e-3 self.fp8 = False self.sequence_parallel = False - def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False): + def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False): inp = paddle.to_tensor(inp, stop_gradient=True) - assert split_input in ['none', 'column', 'row'] - if split_input == 'column': + assert split_input in ["none", "column", "row"] + if split_input == "column": split_size = inp.shape[1] // self.world_size - input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)] - elif split_input == 'row': + input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)] + elif split_input == "row": split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :] + input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] else: input_parallel = inp input_parallel.stop_gradient = False @@ -70,12 +70,12 @@ class TestLayerNormLinearTp(unittest.TestCase): loss.backward() optimizer.step() optimizer.clear_grad() - if split_input != 'none': + if split_input != "none": grad_input = [] paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) - if split_input == 'column': + if split_input == "column": grad_input = paddle.concat(grad_input, axis=1) - elif split_input == 'row': + elif split_input == "row": grad_input = paddle.concat(grad_input, axis=0) else: grad_input = input_parallel.grad @@ -88,14 +88,14 @@ class TestLayerNormLinearTp(unittest.TestCase): self.in_features, self.out_features, eps=self.eps, - parallel_mode='column', + parallel_mode="column", sequence_parallel=self.sequence_parallel, ) layer_pd = te.LayerNormLinear( self.in_features, self.out_features, eps=self.eps, - backend='paddle', + backend="paddle", ) # Get total weight total_weight = [] @@ -104,8 +104,9 @@ class TestLayerNormLinearTp(unittest.TestCase): total_weight = paddle.concat(total_weight, axis=0) layer_pd.weight.copy_(total_weight.T, True) - assert_shape(layer_te.weight, - [self.out_features // self.model_parallel_size, self.in_features]) + assert_shape( + layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features] + ) assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) @@ -121,8 +122,9 @@ class TestLayerNormLinearTp(unittest.TestCase): layer_te, inp, optimizer_te, - split_input='row' if self.sequence_parallel else 'none', - gather_output=True) + split_input="row" if self.sequence_parallel else "none", + gather_output=True, + ) loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) @@ 
-136,7 +138,7 @@ class TestLayerNormLinearTpFp8(TestLayerNormLinearTp): self.batch_size = 16 self.in_features = 32 self.out_features = 64 - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" self.rtol = 1e-2 self.atol = 1e-2 self.eps = 1e-3 @@ -152,7 +154,7 @@ class TestLayerNormLinearSp(TestLayerNormLinearTp): self.batch_size = 16 self.in_features = 32 self.out_features = 64 - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" self.rtol = 1e-3 self.atol = 1e-3 self.eps = 1e-3 @@ -168,7 +170,7 @@ class TestLayerNormLinearSpFp8(TestLayerNormLinearTp): self.batch_size = 16 self.in_features = 32 self.out_features = 64 - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" self.rtol = 1e-2 self.atol = 1e-2 self.eps = 1e-3 @@ -176,5 +178,5 @@ class TestLayerNormLinearSpFp8(TestLayerNormLinearTp): self.sequence_parallel = True -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/paddle/parallel_tests/layernorm_mlp_tp.py b/tests/paddle/parallel_tests/layernorm_mlp_tp.py index e5b391a2f0e3fb89f6700101996ee711a99f10ed..f23cfb9e3f33f4a8c4ce589a49b82166b0520b63 100644 --- a/tests/paddle/parallel_tests/layernorm_mlp_tp.py +++ b/tests/paddle/parallel_tests/layernorm_mlp_tp.py @@ -42,22 +42,22 @@ class TestLayerNormMLPTp(unittest.TestCase): self.batch_size = 16 self.hidden_size = 32 self.ffn_hidden_size = 64 - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" self.rtol = 1e-3 self.atol = 1e-3 self.eps = 1e-3 self.fp8 = False self.sequence_parallel = False - def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False): + def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False): inp = paddle.to_tensor(inp, stop_gradient=True) - assert split_input in ['none', 'column', 'row'] - if split_input == 'column': + assert split_input in ["none", "column", "row"] + if split_input == "column": split_size = inp.shape[1] // self.world_size - input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)] - elif split_input == 'row': + input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)] + elif split_input == "row": split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :] + input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] else: input_parallel = inp input_parallel.stop_gradient = False @@ -71,12 +71,12 @@ class TestLayerNormMLPTp(unittest.TestCase): loss.backward() optimizer.step() optimizer.clear_grad() - if split_input != 'none': + if split_input != "none": grad_input = [] paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) - if split_input == 'column': + if split_input == "column": grad_input = paddle.concat(grad_input, axis=1) - elif split_input == 'row': + elif split_input == "row": grad_input = paddle.concat(grad_input, axis=0) else: grad_input = input_parallel.grad @@ -97,7 +97,7 @@ class TestLayerNormMLPTp(unittest.TestCase): ffn_hidden_size=self.ffn_hidden_size, eps=self.eps, set_parallel_mode=False, - backend='paddle', + backend="paddle", ) def _get_total_weight(local_weight, tp_group, axis): @@ -113,11 +113,15 @@ class TestLayerNormMLPTp(unittest.TestCase): layer_pd.fc1_weight.copy_(total_fc1_weight.T, True) layer_pd.fc2_weight.copy_(total_fc2_weight.T, True) - assert_shape(layer_te.fc1_weight, - [self.ffn_hidden_size // self.model_parallel_size, self.hidden_size]) + assert_shape( + 
layer_te.fc1_weight, + [self.ffn_hidden_size // self.model_parallel_size, self.hidden_size], + ) assert_shape(layer_te.fc1_bias, [self.ffn_hidden_size // self.model_parallel_size]) - assert_shape(layer_te.fc2_weight, - [self.hidden_size, self.ffn_hidden_size // self.model_parallel_size]) + assert_shape( + layer_te.fc2_weight, + [self.hidden_size, self.ffn_hidden_size // self.model_parallel_size], + ) assert_shape(layer_te.fc2_bias, [self.hidden_size]) optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) @@ -133,8 +137,9 @@ class TestLayerNormMLPTp(unittest.TestCase): layer_te, inp, optimizer_te, - split_input='row' if self.sequence_parallel else 'none', - gather_output=self.sequence_parallel) + split_input="row" if self.sequence_parallel else "none", + gather_output=self.sequence_parallel, + ) loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) @@ -148,7 +153,7 @@ class TestLayerNormMLPTpFp8(TestLayerNormMLPTp): self.batch_size = 16 self.hidden_size = 32 self.ffn_hidden_size = 64 - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" self.rtol = 1e-2 self.atol = 1e-2 self.eps = 1e-3 @@ -164,7 +169,7 @@ class TestLayerNormMLPSp(TestLayerNormMLPTp): self.batch_size = 16 self.hidden_size = 32 self.ffn_hidden_size = 64 - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" self.rtol = 1e-3 self.atol = 1e-3 self.eps = 1e-3 @@ -180,7 +185,7 @@ class TestLayerNormMLPSpFp8(TestLayerNormMLPTp): self.batch_size = 16 self.hidden_size = 32 self.ffn_hidden_size = 64 - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" self.rtol = 1e-2 self.atol = 1e-2 self.eps = 1e-3 @@ -188,5 +193,5 @@ class TestLayerNormMLPSpFp8(TestLayerNormMLPTp): self.sequence_parallel = True -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/paddle/parallel_tests/linear_pp.py b/tests/paddle/parallel_tests/linear_pp.py index 26db9dd97701e141653d46947122b5d1bc9130cc..0e7e90611e31612bdb7cb5847a3b13c329ec48b2 100644 --- a/tests/paddle/parallel_tests/linear_pp.py +++ b/tests/paddle/parallel_tests/linear_pp.py @@ -23,14 +23,14 @@ class TELinear(te.Linear): """To pass is_first_microbatch""" def __init__(self, *args, **kwargs): - assert 'accumulate_steps' in kwargs - self.accumulate_steps = kwargs['accumulate_steps'] - del kwargs['accumulate_steps'] + assert "accumulate_steps" in kwargs + self.accumulate_steps = kwargs["accumulate_steps"] + del kwargs["accumulate_steps"] self._micro_batch_id = 0 super().__init__(*args, **kwargs) def forward(self, *args, **kwargs): - kwargs['is_first_microbatch'] = (self._micro_batch_id % self.accumulate_steps) == 0 + kwargs["is_first_microbatch"] = (self._micro_batch_id % self.accumulate_steps) == 0 if paddle.is_grad_enabled() and self.training: self._micro_batch_id += 1 return super().forward(*args, **kwargs) @@ -39,14 +39,16 @@ class TELinear(te.Linear): class TEPipelineModel(PipelineLayer): """Model for pipeline parallel test""" - def __init__(self, - in_features, - hidden_features, - weight_attrs, - use_te=True, - use_fp8=False, - accumulate_steps=1, - **kwargs): + def __init__( + self, + in_features, + hidden_features, + weight_attrs, + use_te=True, + use_fp8=False, + accumulate_steps=1, + **kwargs, + ): self.in_features = in_features self.hidden_features = hidden_features self.fp8 = use_fp8 @@ -56,19 +58,23 @@ class 
TEPipelineModel(PipelineLayer): Linear = TELinear if use_te else paddle.nn.Linear extra_kwargs = {} if use_te: - extra_kwargs['accumulate_steps'] = accumulate_steps + extra_kwargs["accumulate_steps"] = accumulate_steps model_desc = [ - LayerDesc(Linear, - self.in_features, - self.hidden_features, - weight_attr=weight_attrs[0], - **extra_kwargs), - LayerDesc(Linear, - self.hidden_features, - self.in_features, - weight_attr=weight_attrs[1], - **extra_kwargs), + LayerDesc( + Linear, + self.in_features, + self.hidden_features, + weight_attr=weight_attrs[0], + **extra_kwargs, + ), + LayerDesc( + Linear, + self.hidden_features, + self.in_features, + weight_attr=weight_attrs[1], + **extra_kwargs, + ), ] super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs) @@ -129,7 +135,7 @@ class TestLinearPipelineParallel(unittest.TestCase): self.micro_batch_size = 16 self.in_features = 32 self.hidden_features = 64 - self.global_dtype = 'float32' + self.global_dtype = "float32" self.rtol = 1e-5 self.atol = 1e-5 self.iter = 10 @@ -164,16 +170,18 @@ class TestLinearPipelineParallel(unittest.TestCase): # Check if model is split across ranks as expected for name, sublayer in pipe_model.named_sublayers(): - if name in ('_loss_fn', 'shared_layers'): + if name in ("_loss_fn", "shared_layers"): continue if self.rank == 0: - assert tuple(sublayer.weight.shape) == weight1_np.T.shape, \ - f"Shape does not match, expect: {weight1_np.T.shape}, " \ + assert tuple(sublayer.weight.shape) == weight1_np.T.shape, ( + f"Shape does not match, expect: {weight1_np.T.shape}, " f"actual: {tuple(sublayer.weight.shape)}" + ) elif self.rank == 1: - assert tuple(sublayer.weight.shape) == weight2_np.T.shape, \ - f"Shape does not match, expect: {weight2_np.T.shape}, " \ + assert tuple(sublayer.weight.shape) == weight2_np.T.shape, ( + f"Shape does not match, expect: {weight2_np.T.shape}, " f"actual: {tuple(sublayer.weight.shape)}" + ) standalone_model = StandaloneModel( self.in_features, @@ -182,8 +190,9 @@ class TestLinearPipelineParallel(unittest.TestCase): ) optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.1, - parameters=standalone_model.parameters()) + optimizer_pd = paddle.optimizer.SGD( + learning_rate=0.1, parameters=standalone_model.parameters() + ) pipe_model = fleet.distributed_model(pipe_model) optimizer_te = fleet.distributed_optimizer(optimizer_te) @@ -196,8 +205,9 @@ class TestLinearPipelineParallel(unittest.TestCase): return loss for i in range(self.iter): - inp = paddle.to_tensor(np.random.normal(size=[self.batch_size, self.in_features]), - dtype=self.global_dtype) + inp = paddle.to_tensor( + np.random.normal(size=[self.batch_size, self.in_features]), dtype=self.global_dtype + ) label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1])) loss_te = pipe_model.train_batch([inp, label], optimizer_te) loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd) @@ -214,12 +224,12 @@ class TestLinearPipelineParallelFP8(TestLinearPipelineParallel): self.micro_batch_size = 16 self.in_features = 32 self.hidden_features = 64 - self.global_dtype = 'float32' + self.global_dtype = "float32" self.rtol = 5e-2 self.atol = 5e-2 self.iter = 10 self.fp8 = True -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/paddle/parallel_tests/linear_tp.py b/tests/paddle/parallel_tests/linear_tp.py index 
ebdabd06bc0c73f5007a5a789fbcfbd318fc443d..4a49474a37f34589a6187d4032635b86fdad85b8 100644 --- a/tests/paddle/parallel_tests/linear_tp.py +++ b/tests/paddle/parallel_tests/linear_tp.py @@ -42,21 +42,21 @@ class TestLinearTp(unittest.TestCase): self.batch_size = 16 self.in_features = 32 self.out_features = 64 - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" self.rtol = 1e-3 self.atol = 1e-3 self.fp8 = False self.sequence_parallel = False - def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False): + def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False): inp = paddle.to_tensor(inp, stop_gradient=True) - assert split_input in ['none', 'column', 'row'] - if split_input == 'column': + assert split_input in ["none", "column", "row"] + if split_input == "column": split_size = inp.shape[1] // self.world_size - input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)] - elif split_input == 'row': + input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)] + elif split_input == "row": split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :] + input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] else: input_parallel = inp input_parallel.stop_gradient = False @@ -69,12 +69,12 @@ class TestLinearTp(unittest.TestCase): loss.backward() optimizer.step() optimizer.clear_grad() - if split_input != 'none': + if split_input != "none": grad_input = [] paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) - if split_input == 'column': + if split_input == "column": grad_input = paddle.concat(grad_input, axis=1) - elif split_input == 'row': + elif split_input == "row": grad_input = paddle.concat(grad_input, axis=0) else: grad_input = input_parallel.grad @@ -86,13 +86,13 @@ class TestLinearTp(unittest.TestCase): layer_te = te.Linear( self.in_features, self.out_features, - parallel_mode='column', + parallel_mode="column", sequence_parallel=self.sequence_parallel, ) layer_pd = te.Linear( self.in_features, self.out_features, - backend='paddle', + backend="paddle", ) # Get total weight total_weight = [] @@ -101,8 +101,9 @@ class TestLinearTp(unittest.TestCase): total_weight = paddle.concat(total_weight, axis=0) layer_pd.weight.copy_(total_weight.T, True) - assert_shape(layer_te.weight, - [self.out_features // self.model_parallel_size, self.in_features]) + assert_shape( + layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features] + ) assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) @@ -118,8 +119,9 @@ class TestLinearTp(unittest.TestCase): layer_te, inp, optimizer_te, - split_input='row' if self.sequence_parallel else 'none', - gather_output=True) + split_input="row" if self.sequence_parallel else "none", + gather_output=True, + ) loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) @@ -130,13 +132,13 @@ class TestLinearTp(unittest.TestCase): layer_te = te.Linear( self.in_features, self.out_features, - parallel_mode='row', + parallel_mode="row", sequence_parallel=self.sequence_parallel, ) layer_pd = te.Linear( self.in_features, self.out_features, - backend='paddle', + 
backend="paddle", ) # Get total weight total_weight = [] @@ -145,8 +147,9 @@ class TestLinearTp(unittest.TestCase): total_weight = paddle.concat(total_weight, axis=1) layer_pd.weight.copy_(total_weight.T, True) - assert_shape(layer_te.weight, - [self.out_features, self.in_features // self.model_parallel_size]) + assert_shape( + layer_te.weight, [self.out_features, self.in_features // self.model_parallel_size] + ) assert_shape(layer_te.bias, [self.out_features]) optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) @@ -158,11 +161,13 @@ class TestLinearTp(unittest.TestCase): for _ in range(5): inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step(layer_te, - inp, - optimizer_te, - split_input='column', - gather_output=self.sequence_parallel) + loss_tp, grad_input = self._train_one_step( + layer_te, + inp, + optimizer_te, + split_input="column", + gather_output=self.sequence_parallel, + ) loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) @@ -176,7 +181,7 @@ class TestLinearTpFP8(TestLinearTp): self.batch_size = 16 self.in_features = 32 self.out_features = 64 - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" self.rtol = 1e-2 self.atol = 1e-2 self.fp8 = True @@ -191,7 +196,7 @@ class TestLinearSp(TestLinearTp): self.batch_size = 16 self.in_features = 32 self.out_features = 64 - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" self.rtol = 1e-3 self.atol = 1e-3 self.fp8 = False @@ -206,12 +211,12 @@ class TestLinearSpFP8(TestLinearTp): self.batch_size = 16 self.in_features = 32 self.out_features = 64 - self.global_dtype = 'bfloat16' + self.global_dtype = "bfloat16" self.rtol = 1e-2 self.atol = 1e-2 self.fp8 = True self.sequence_parallel = True -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/paddle/parallel_tests/transformer_tp.py b/tests/paddle/parallel_tests/transformer_tp.py index 2853eda4d7277743b419ee6ec99bc73fd233afd6..5506be042ff11e09f7706bd13879faa898ac5c0b 100644 --- a/tests/paddle/parallel_tests/transformer_tp.py +++ b/tests/paddle/parallel_tests/transformer_tp.py @@ -45,9 +45,9 @@ class TestTransformerTp(unittest.TestCase): self.ffn_hidden_size = 4096 self.q_seqlen = 128 self.kv_seqlen = 128 - self.mask_type = 'padding' - self.layer_type = 'encoder' - self.global_dtype = 'bfloat16' + self.mask_type = "padding" + self.layer_type = "encoder" + self.global_dtype = "bfloat16" self.rtol = 5e-2 self.atol = 5e-2 self.eps = 1e-3 @@ -58,7 +58,7 @@ class TestTransformerTp(unittest.TestCase): inp, mask = inp_list if sequence_parallel: split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :] + input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] else: input_parallel = inp with te.fp8_autocast(enabled=fp8_enabled): @@ -83,16 +83,18 @@ class TestTransformerTp(unittest.TestCase): self.num_heads, ] common_kwargs = { - 'layernorm_epsilon': self.eps, - 'hidden_dropout': 0.0, - 'attention_dropout': 0.0, - 'self_attn_mask_type': self.mask_type, - 'layer_type': self.layer_type, + "layernorm_epsilon": self.eps, + "hidden_dropout": 0.0, + "attention_dropout": 0.0, + "self_attn_mask_type": self.mask_type, + "layer_type": 
self.layer_type, } - layer_tp = te.TransformerLayer(*common_args, - **common_kwargs, - set_parallel_mode=True, - sequence_parallel=self.sequence_parallel) + layer_tp = te.TransformerLayer( + *common_args, + **common_kwargs, + set_parallel_mode=True, + sequence_parallel=self.sequence_parallel, + ) layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False) def _get_total_weight(local_weight, tp_group, axis, interleave=False): @@ -103,12 +105,15 @@ class TestTransformerTp(unittest.TestCase): # Due to the interleaved qkv layout, need to concat on num_head # dimension for column parallel linear in MultiHeadAttention layer assert axis == 0 - assert [3 * self.hidden_size // self.world_size, - self.hidden_size] == partial_weight.shape + assert [ + 3 * self.hidden_size // self.world_size, + self.hidden_size, + ] == partial_weight.shape local_num_head = self.num_heads // self.world_size for idx, _ in enumerate(total_weight): total_weight[idx] = total_weight[idx].reshape( - [3, local_num_head, -1, self.hidden_size]) + [3, local_num_head, -1, self.hidden_size] + ) total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size]) else: total_weight = paddle.concat(total_weight, axis=axis) @@ -124,48 +129,56 @@ class TestTransformerTp(unittest.TestCase): weight_dst = _get_weight(layer_dst, weight_names) if partition_mode is None: total_weight = weight_src - elif partition_mode == 'column': - total_weight = _get_total_weight(weight_src, - tp_group=self.tp_group, - axis=0, - interleave=interleave) - elif partition_mode == 'row': + elif partition_mode == "column": + total_weight = _get_total_weight( + weight_src, tp_group=self.tp_group, axis=0, interleave=interleave + ) + elif partition_mode == "row": total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1) else: raise ValueError(f"Partition Mode {partition_mode} is not supported.") - assert weight_dst.shape == total_weight.shape, \ - f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." + assert ( + weight_dst.shape == total_weight.shape + ), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." 
weight_dst.copy_(total_weight, True) - copy_weight(layer_tp, layer_single, None, ['self_attention', 'layernorm_qkv', 'ln_weight']) - copy_weight(layer_tp, - layer_single, - 'column', ['self_attention', 'layernorm_qkv', 'weight'], - interleave=True) - copy_weight(layer_tp, layer_single, 'row', ['self_attention', 'proj', 'weight']) - copy_weight(layer_tp, layer_single, None, ['layernorm_mlp', 'ln_weight']) - copy_weight(layer_tp, layer_single, 'column', ['layernorm_mlp', 'fc1_weight']) - copy_weight(layer_tp, layer_single, 'row', ['layernorm_mlp', 'fc2_weight']) + copy_weight(layer_tp, layer_single, None, ["self_attention", "layernorm_qkv", "ln_weight"]) + copy_weight( + layer_tp, + layer_single, + "column", + ["self_attention", "layernorm_qkv", "weight"], + interleave=True, + ) + copy_weight(layer_tp, layer_single, "row", ["self_attention", "proj", "weight"]) + copy_weight(layer_tp, layer_single, None, ["layernorm_mlp", "ln_weight"]) + copy_weight(layer_tp, layer_single, "column", ["layernorm_mlp", "fc1_weight"]) + copy_weight(layer_tp, layer_single, "row", ["layernorm_mlp", "fc2_weight"]) if self.sequence_parallel: register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1) optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters()) - optimizer_single = paddle.optimizer.SGD(learning_rate=0.01, - parameters=layer_single.parameters()) + optimizer_single = paddle.optimizer.SGD( + learning_rate=0.01, parameters=layer_single.parameters() + ) layer_tp = fleet.distributed_model(layer_tp) optimizer_tp = fleet.distributed_optimizer(optimizer_tp) for _ in range(5): - inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size], - self.global_dtype) - mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), - dtype='bool') - loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8, - self.sequence_parallel) - loss_single, out_single = self._train_one_step(layer_single, [inp, mask], - optimizer_single, self.fp8) + inp = paddle.uniform( + [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype + ) + mask = paddle.zeros( + shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool" + ) + loss_tp, out_tp = self._train_one_step( + layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel + ) + loss_single, out_single = self._train_one_step( + layer_single, [inp, mask], optimizer_single, self.fp8 + ) assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol) assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol) @@ -181,9 +194,9 @@ class TestTransformerTpFp8(TestTransformerTp): self.ffn_hidden_size = 4096 self.q_seqlen = 128 self.kv_seqlen = 128 - self.mask_type = 'padding' - self.layer_type = 'encoder' - self.global_dtype = 'bfloat16' + self.mask_type = "padding" + self.layer_type = "encoder" + self.global_dtype = "bfloat16" self.rtol = 5e-2 self.atol = 0.5 self.eps = 1e-3 @@ -202,9 +215,9 @@ class TestTransformerSp(TestTransformerTp): self.ffn_hidden_size = 4096 self.q_seqlen = 128 self.kv_seqlen = 128 - self.mask_type = 'padding' - self.layer_type = 'encoder' - self.global_dtype = 'bfloat16' + self.mask_type = "padding" + self.layer_type = "encoder" + self.global_dtype = "bfloat16" self.rtol = 5e-2 self.atol = 5e-2 self.eps = 1e-3 @@ -223,9 +236,9 @@ class TestTransformerSpFp8(TestTransformerSp): self.ffn_hidden_size = 4096 self.q_seqlen = 128 self.kv_seqlen = 128 - self.mask_type = 'padding' - self.layer_type = 'encoder' - 
self.global_dtype = 'bfloat16' + self.mask_type = "padding" + self.layer_type = "encoder" + self.global_dtype = "bfloat16" self.rtol = 5e-2 self.atol = 0.5 self.eps = 1e-3 @@ -233,5 +246,5 @@ class TestTransformerSpFp8(TestTransformerSp): self.sequence_parallel = True -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/paddle/recompute_tests/recompute_transformer_encoder.py b/tests/paddle/recompute_tests/recompute_transformer_encoder.py index 8acf23befd609706b48c52c6148acf16528705f5..56d0c2453558025e81f9e7cb7ba1083be25b18e9 100644 --- a/tests/paddle/recompute_tests/recompute_transformer_encoder.py +++ b/tests/paddle/recompute_tests/recompute_transformer_encoder.py @@ -37,14 +37,17 @@ def main(): enable_recompute = int(sys.argv[1]) use_reentrant = int(sys.argv[2]) - layers = paddle.nn.LayerList([ - te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - layer_type='encoder', - ) for _ in range(num_layers) - ]) + layers = paddle.nn.LayerList( + [ + te.TransformerLayer( + hidden_size, + ffn_hidden_size, + num_heads, + layer_type="encoder", + ) + for _ in range(num_layers) + ] + ) model = Net(layers) optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters()) @@ -52,7 +55,7 @@ def main(): for _ in range(10): inp = paddle.uniform([batch_size, q_seqlen, hidden_size]) inp.stop_gradient = False - mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype='bool') + mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype="bool") with te.fp8_autocast(enabled=True): out = model(inp, mask, enable_recompute, use_reentrant) loss = out.mean() diff --git a/tests/paddle/test_install.py b/tests/paddle/test_install.py index b6e1e4673f36b6b40934e143d26958bab636962e..686771ec09e8e7a80fc5b22d4373389429d957c9 100644 --- a/tests/paddle/test_install.py +++ b/tests/paddle/test_install.py @@ -8,4 +8,4 @@ def test_import(): """ Test if Paddle extension can be imported normally """ - import transformer_engine.paddle # pylint: disable=unused-import + import transformer_engine.paddle # pylint: disable=unused-import diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py index 1aeafe030ace36981a8c8b5f4520a26d83ef5557..6a985d7e86675476266f1466164199d784b02f54 100644 --- a/tests/paddle/test_layers.py +++ b/tests/paddle/test_layers.py @@ -26,14 +26,14 @@ def setup(): @pytest.mark.skipif(not is_fp8_supported, reason=reason) -@pytest.mark.parametrize('use_fp8', [True, False]) +@pytest.mark.parametrize("use_fp8", [True, False]) def test_checkpoint(use_fp8): """Test checkpoint save / load""" bs = 16 in_features = 16 out_features = 32 file_name = "model.pdparams" - input_tensor = paddle.uniform(shape=(bs, in_features), dtype='float32') + input_tensor = paddle.uniform(shape=(bs, in_features), dtype="float32") model = te.Linear(in_features, out_features) model_loaded = te.Linear(in_features, out_features) # Populate amax_history @@ -91,15 +91,18 @@ class TestLinear: """ @staticmethod - @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 Linear requires Ampere+ GPU") - @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES) - @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize('no_dgrad', [True, False]) - @pytest.mark.parametrize('no_wgrad', [True, False]) - @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) - def test_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, 
no_dgrad, no_wgrad, - activation_dtype): + @pytest.mark.skipif( + paddle.device.cuda.get_device_capability() < (8, 0), + reason="BF16 Linear requires Ampere+ GPU", + ) + @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) + @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) + @pytest.mark.parametrize("no_dgrad", [True, False]) + @pytest.mark.parametrize("no_wgrad", [True, False]) + @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) + def test_linear_bf16( + bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype + ): """ Test BF16 Linear """ @@ -112,10 +115,9 @@ class TestLinear: paddle.set_default_dtype(activation_dtype) layer_te = te.Linear(in_features, out_features, bias_attr=None if has_bias else False) - layer_pd = te.Linear(in_features, - out_features, - bias_attr=None if has_bias else False, - backend='paddle') + layer_pd = te.Linear( + in_features, out_features, bias_attr=None if has_bias else False, backend="paddle" + ) layer_pd.weight.copy_(layer_te.weight.T, True) if has_bias: layer_pd.bias.copy_(layer_te.bias, True) @@ -139,15 +141,25 @@ class TestLinear: @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES) - @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize('no_dgrad', [True, False]) - @pytest.mark.parametrize('no_wgrad', [True, False]) - @pytest.mark.parametrize('fp8_wgrad', [True, False]) - @pytest.mark.parametrize('do_calibration', [True, False]) - @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) - def test_linear_fp8(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad, - fp8_wgrad, do_calibration, activation_dtype): + @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) + @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) + @pytest.mark.parametrize("no_dgrad", [True, False]) + @pytest.mark.parametrize("no_wgrad", [True, False]) + @pytest.mark.parametrize("fp8_wgrad", [True, False]) + @pytest.mark.parametrize("do_calibration", [True, False]) + @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) + def test_linear_fp8( + bs, + in_features, + out_features, + has_bias, + no_dbias, + no_dgrad, + no_wgrad, + fp8_wgrad, + do_calibration, + activation_dtype, + ): """ Test FP8 Linear """ @@ -170,7 +182,7 @@ class TestLinear: in_features=in_features, out_features=out_features, bias_attr=None if has_bias else False, - backend='paddle', + backend="paddle", ) layer_pd.weight.copy_(layer_te.weight.T, True) if has_bias: @@ -182,8 +194,9 @@ class TestLinear: layer_te.bias.stop_gradient = no_dbias layer_pd.bias.stop_gradient = no_dbias - with fp8_autocast(enabled=not do_calibration, calibrating=do_calibration, - fp8_recipe=recipe): + with fp8_autocast( + enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe + ): out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out) out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out) @@ -199,9 +212,9 @@ class TestLinear: @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES) - @pytest.mark.parametrize('activation_dtype', ['bfloat16']) - @pytest.mark.parametrize('num_microbatch', [8]) + 
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) + @pytest.mark.parametrize("activation_dtype", ["bfloat16"]) + @pytest.mark.parametrize("num_microbatch", [8]) def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch): """ Test FP8 Linear @@ -236,17 +249,16 @@ class TestLinear: out_ref.backward(grad_out) assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose(layer_cached.weight.grad, - layer_normal.weight.grad, - rtol=rtol, - atol=atol) + assert_allclose( + layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol + ) -@pytest.mark.parametrize('bs,hidden_size', NORM_CASES) -@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) -@pytest.mark.parametrize('no_dgrad', [True, False]) -@pytest.mark.parametrize('no_wgrad', [True, False]) -@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) +@pytest.mark.parametrize("bs,hidden_size", NORM_CASES) +@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) +@pytest.mark.parametrize("no_dgrad", [True, False]) +@pytest.mark.parametrize("no_wgrad", [True, False]) +@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype): """ Test BF16 LayerNorm @@ -261,10 +273,9 @@ def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, paddle.set_default_dtype(activation_dtype) layer_te = te.LayerNorm(hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False) - layer_pd = te.LayerNorm(hidden_size=hidden_size, - eps=eps, - bias_attr=None if has_bias else False, - backend='paddle') + layer_pd = te.LayerNorm( + hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False, backend="paddle" + ) layer_pd.weight.copy_(layer_te.weight, True) if has_bias: layer_pd.bias.copy_(layer_te.bias, True) @@ -293,17 +304,29 @@ class TestLayerNormLinear: """ @staticmethod - @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 Linear requires Ampere+ GPU") - @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES) - @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize('no_dgrad', [True, False]) - @pytest.mark.parametrize('no_wgrad', [True, False]) - @pytest.mark.parametrize('return_ln_out', [True, False]) - @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) - @pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm']) - def test_layernorm_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, - no_wgrad, return_ln_out, activation_dtype, normalization): + @pytest.mark.skipif( + paddle.device.cuda.get_device_capability() < (8, 0), + reason="BF16 Linear requires Ampere+ GPU", + ) + @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) + @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) + @pytest.mark.parametrize("no_dgrad", [True, False]) + @pytest.mark.parametrize("no_wgrad", [True, False]) + @pytest.mark.parametrize("return_ln_out", [True, False]) + @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) + @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) + def test_layernorm_linear_bf16( + bs, + in_features, + out_features, + has_bias, + no_dbias, + no_dgrad, + no_wgrad, + return_ln_out, + activation_dtype, 
+ normalization, + ): """ Test BF16 LayerNormLinear Layer """ @@ -315,7 +338,7 @@ class TestLayerNormLinear: input_tensor.stop_gradient = no_dgrad grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) eps = 1e-3 - has_ln_bias = normalization == 'LayerNorm' + has_ln_bias = normalization == "LayerNorm" layer_te = te.LayerNormLinear( in_features=in_features, @@ -333,7 +356,7 @@ class TestLayerNormLinear: normalization=normalization, bias_attr=None if has_bias else False, return_layernorm_output=return_ln_out, - backend='paddle', + backend="paddle", ) layer_pd.ln_weight.copy_(layer_te.ln_weight, True) @@ -355,11 +378,11 @@ class TestLayerNormLinear: layer_pd.bias.stop_gradient = no_dbias out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out) - out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te, - input_tensor, - grad_out, - return_ln_out=return_ln_out) + layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out + ) + out, ln_out, grad_input = calc_output_and_grad_ln_out( + layer_te, input_tensor, grad_out, return_ln_out=return_ln_out + ) assert_allclose(out, out_ref, rtol=rtol, atol=atol) if not no_dgrad: @@ -377,18 +400,29 @@ class TestLayerNormLinear: @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES) - @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize('no_dgrad', [True, False]) - @pytest.mark.parametrize('no_wgrad', [True, False]) - @pytest.mark.parametrize('fp8_wgrad', [True, False]) - @pytest.mark.parametrize('do_calibration', [True, False]) - @pytest.mark.parametrize('return_ln_out', [True, False]) - @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) - @pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm']) - def test_layernorm_linear_fp8(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, - no_wgrad, fp8_wgrad, do_calibration, return_ln_out, - activation_dtype, normalization): + @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) + @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) + @pytest.mark.parametrize("no_dgrad", [True, False]) + @pytest.mark.parametrize("no_wgrad", [True, False]) + @pytest.mark.parametrize("fp8_wgrad", [True, False]) + @pytest.mark.parametrize("do_calibration", [True, False]) + @pytest.mark.parametrize("return_ln_out", [True, False]) + @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) + @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) + def test_layernorm_linear_fp8( + bs, + in_features, + out_features, + has_bias, + no_dbias, + no_dgrad, + no_wgrad, + fp8_wgrad, + do_calibration, + return_ln_out, + activation_dtype, + normalization, + ): """ Test FP8 LayerNormLinear Layer """ @@ -400,7 +434,7 @@ class TestLayerNormLinear: input_tensor.stop_gradient = no_dgrad grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) eps = 1e-3 - has_ln_bias = normalization == 'LayerNorm' + has_ln_bias = normalization == "LayerNorm" recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) @@ -420,7 +454,7 @@ class TestLayerNormLinear: normalization=normalization, bias_attr=None if has_bias else False, return_layernorm_output=return_ln_out, - backend='paddle', + backend="paddle", ) layer_pd.ln_weight.copy_(layer_te.ln_weight, 
True) @@ -441,14 +475,15 @@ class TestLayerNormLinear: layer_te.bias.stop_gradient = no_dbias layer_pd.bias.stop_gradient = no_dbias - with fp8_autocast(enabled=not do_calibration, calibrating=do_calibration, - fp8_recipe=recipe): + with fp8_autocast( + enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe + ): out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out) - out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te, - input_tensor, - grad_out, - return_ln_out=return_ln_out) + layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out + ) + out, ln_out, grad_input = calc_output_and_grad_ln_out( + layer_te, input_tensor, grad_out, return_ln_out=return_ln_out + ) assert_allclose(out, out_ref, rtol=rtol, atol=atol) if not no_dgrad: @@ -468,11 +503,12 @@ class TestLayerNormLinear: @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES) - @pytest.mark.parametrize('activation_dtype', ['bfloat16']) - @pytest.mark.parametrize('num_microbatch', [8]) - def test_layernorm_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, - num_microbatch): + @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) + @pytest.mark.parametrize("activation_dtype", ["bfloat16"]) + @pytest.mark.parametrize("num_microbatch", [8]) + def test_layernorm_linear_fp8_microbatch( + bs, in_features, out_features, activation_dtype, num_microbatch + ): """ Test FP8 LayerNormLinear Layer """ @@ -513,14 +549,12 @@ class TestLayerNormLinear: out_ref.backward(grad_out) assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose(layer_cached.weight.grad, - layer_normal.weight.grad, - rtol=rtol, - atol=atol) - assert_allclose(layer_cached.ln_weight.grad, - layer_normal.ln_weight.grad, - rtol=rtol, - atol=atol) + assert_allclose( + layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol + ) + assert_allclose( + layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol + ) class TestLayerNormMLP: @@ -529,19 +563,31 @@ class TestLayerNormMLP: """ @staticmethod - @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 Linear requires Ampere+ GPU") - @pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES) - @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize('no_dgrad', [True, False]) - @pytest.mark.parametrize('no_wgrad', [True, False]) - @pytest.mark.parametrize('return_ln_out', [True, False]) - @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) - @pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm']) - @pytest.mark.parametrize('activation', ['gelu', 'swiglu']) - def test_layernorm_mlp_bf16(bs, hidden_size, ffn_hidden_size, has_bias, no_dbias, no_dgrad, - no_wgrad, return_ln_out, activation_dtype, normalization, - activation): + @pytest.mark.skipif( + paddle.device.cuda.get_device_capability() < (8, 0), + reason="BF16 Linear requires Ampere+ GPU", + ) + @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES) + @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) + @pytest.mark.parametrize("no_dgrad", [True, False]) + @pytest.mark.parametrize("no_wgrad", [True, False]) + @pytest.mark.parametrize("return_ln_out", [True, False]) + 
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) + @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) + @pytest.mark.parametrize("activation", ["gelu", "swiglu"]) + def test_layernorm_mlp_bf16( + bs, + hidden_size, + ffn_hidden_size, + has_bias, + no_dbias, + no_dgrad, + no_wgrad, + return_ln_out, + activation_dtype, + normalization, + activation, + ): """ Tests for TestLayerNormMLP layer """ @@ -553,7 +599,7 @@ class TestLayerNormMLP: input_tensor.stop_gradient = no_dgrad grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) eps = 1e-3 - has_ln_bias = normalization == 'LayerNorm' + has_ln_bias = normalization == "LayerNorm" layer_te = te.LayerNormMLP( hidden_size=hidden_size, @@ -572,7 +618,7 @@ class TestLayerNormMLP: activation=activation, bias_attr=None if has_bias else False, return_layernorm_output=return_ln_out, - backend='paddle', + backend="paddle", ) layer_pd.ln_weight.copy_(layer_te.ln_weight, True) if has_ln_bias: @@ -599,55 +645,63 @@ class TestLayerNormMLP: layer_pd.fc2_bias.stop_gradient = no_dbias out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out) - out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te, - input_tensor, - grad_out, - return_ln_out=return_ln_out) + layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out + ) + out, ln_out, grad_input = calc_output_and_grad_ln_out( + layer_te, input_tensor, grad_out, return_ln_out=return_ln_out + ) assert_allclose(out, out_ref, rtol=rtol, atol=atol) if not no_dgrad: assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) if not no_wgrad: assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - assert_allclose(layer_te.fc1_weight.grad, - layer_pd.fc1_weight.grad.T, - rtol=rtol, - atol=atol) - assert_allclose(layer_te.fc2_weight.grad, - layer_pd.fc2_weight.grad.T, - rtol=rtol, - atol=atol) + assert_allclose( + layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol + ) + assert_allclose( + layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol + ) if not no_dbias: if has_ln_bias: assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) if has_bias: - assert_allclose(layer_te.fc1_bias.grad, - layer_pd.fc1_bias.grad, - rtol=rtol, - atol=atol) - assert_allclose(layer_te.fc2_bias.grad, - layer_pd.fc2_bias.grad, - rtol=rtol, - atol=atol) + assert_allclose( + layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol + ) + assert_allclose( + layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol + ) if return_ln_out: assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES) - @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize('no_dgrad', [True, False]) - @pytest.mark.parametrize('no_wgrad', [True, False]) - @pytest.mark.parametrize('fp8_wgrad', [True, False]) - @pytest.mark.parametrize('do_calibration', [True, False]) - @pytest.mark.parametrize('return_ln_out', [True, False]) - @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) - @pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm']) - @pytest.mark.parametrize('activation', ['gelu', 'swiglu']) - def test_layernorm_mlp_fp8(bs, hidden_size, ffn_hidden_size, 
has_bias, no_dbias, no_dgrad, - no_wgrad, fp8_wgrad, do_calibration, return_ln_out, activation_dtype, - normalization, activation): + @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES) + @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) + @pytest.mark.parametrize("no_dgrad", [True, False]) + @pytest.mark.parametrize("no_wgrad", [True, False]) + @pytest.mark.parametrize("fp8_wgrad", [True, False]) + @pytest.mark.parametrize("do_calibration", [True, False]) + @pytest.mark.parametrize("return_ln_out", [True, False]) + @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) + @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) + @pytest.mark.parametrize("activation", ["gelu", "swiglu"]) + def test_layernorm_mlp_fp8( + bs, + hidden_size, + ffn_hidden_size, + has_bias, + no_dbias, + no_dgrad, + no_wgrad, + fp8_wgrad, + do_calibration, + return_ln_out, + activation_dtype, + normalization, + activation, + ): """ Test FP8 LayerNormMLP Layer """ @@ -659,7 +713,7 @@ class TestLayerNormMLP: input_tensor.stop_gradient = no_dgrad grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) eps = 1e-3 - has_ln_bias = normalization == 'LayerNorm' + has_ln_bias = normalization == "LayerNorm" recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) @@ -681,7 +735,7 @@ class TestLayerNormMLP: activation=activation, bias_attr=None if has_bias else False, return_layernorm_output=return_ln_out, - backend='paddle', + backend="paddle", ) layer_pd.ln_weight.copy_(layer_te.ln_weight, True) if has_ln_bias: @@ -707,40 +761,37 @@ class TestLayerNormMLP: layer_pd.fc1_bias.stop_gradient = no_dbias layer_pd.fc2_bias.stop_gradient = no_dbias - with fp8_autocast(enabled=not do_calibration, calibrating=do_calibration, - fp8_recipe=recipe): + with fp8_autocast( + enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe + ): out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out) - out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te, - input_tensor, - grad_out, - return_ln_out=return_ln_out) + layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out + ) + out, ln_out, grad_input = calc_output_and_grad_ln_out( + layer_te, input_tensor, grad_out, return_ln_out=return_ln_out + ) assert_allclose(out, out_ref, rtol=rtol, atol=atol) if not no_dgrad: assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) if not no_wgrad: assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - assert_allclose(layer_te.fc1_weight.grad, - layer_pd.fc1_weight.grad.T, - rtol=rtol, - atol=atol) - assert_allclose(layer_te.fc2_weight.grad, - layer_pd.fc2_weight.grad.T, - rtol=rtol, - atol=atol) + assert_allclose( + layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol + ) + assert_allclose( + layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol + ) if not no_dbias: if has_ln_bias: assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) if has_bias: - assert_allclose(layer_te.fc1_bias.grad, - layer_pd.fc1_bias.grad, - rtol=rtol, - atol=atol) - assert_allclose(layer_te.fc2_bias.grad, - layer_pd.fc2_bias.grad, - rtol=rtol, - atol=atol) + assert_allclose( + layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol + ) + assert_allclose( + layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, 
atol=atol + ) if return_ln_out: assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) @@ -749,11 +800,12 @@ class TestLayerNormMLP: @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES) - @pytest.mark.parametrize('activation_dtype', ['bfloat16']) - @pytest.mark.parametrize('num_microbatch', [8]) - def test_layernorm_mlp_fp8_microbatch(bs, hidden_size, ffn_hidden_size, activation_dtype, - num_microbatch): + @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES) + @pytest.mark.parametrize("activation_dtype", ["bfloat16"]) + @pytest.mark.parametrize("num_microbatch", [8]) + def test_layernorm_mlp_fp8_microbatch( + bs, hidden_size, ffn_hidden_size, activation_dtype, num_microbatch + ): """ Test FP8 LayerNormMLP Layer """ @@ -803,28 +855,26 @@ class TestLayerNormMLP: out_ref.backward(grad_out) assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose(layer_cached.ln_weight.grad, - layer_normal.ln_weight.grad, - rtol=rtol, - atol=atol) - assert_allclose(layer_cached.fc1_weight.grad, - layer_normal.fc1_weight.grad, - rtol=rtol, - atol=atol) - assert_allclose(layer_cached.fc2_weight.grad, - layer_normal.fc2_weight.grad, - rtol=rtol, - atol=atol) - - -@pytest.mark.parametrize('bs', [1, 2]) -@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16]]) -@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[1024, 1024]]) -@pytest.mark.parametrize('attn_type', ['self', 'cross']) -@pytest.mark.parametrize('mask_type', ['causal', 'padding']) -@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16']) -def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, - mask_type, math_dtype): + assert_allclose( + layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol + ) + assert_allclose( + layer_cached.fc1_weight.grad, layer_normal.fc1_weight.grad, rtol=rtol, atol=atol + ) + assert_allclose( + layer_cached.fc2_weight.grad, layer_normal.fc2_weight.grad, rtol=rtol, atol=atol + ) + + +@pytest.mark.parametrize("bs", [1, 2]) +@pytest.mark.parametrize("hidden_size, num_heads", [[1024, 16]]) +@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]]) +@pytest.mark.parametrize("attn_type", ["self", "cross"]) +@pytest.mark.parametrize("mask_type", ["causal", "padding"]) +@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) +def test_dot_product_attention( + bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype +): """ Test DotProductAttention Layer """ @@ -835,53 +885,64 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, # Skip if cuDNN fused attention is not supported if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_heads, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=head_size, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type=mask_type, + num_heads=num_heads, + num_gqa_groups=num_heads, + q_seqlen=q_seqlen, + kv_seqlen=kv_seqlen, + head_size=head_size, + dtype=math_dtype, + dropout=0.0, + qkv_layout="bshd_bshd_bshd", + bias_type="no_bias", + mask_type=mask_type, ): pytest.skip("cuDNN fused attention is not supported") - attn_q_input = paddle.normal(mean=0.0, std=0.02, - shape=(bs, q_seqlen, num_heads, head_size)).astype(math_dtype) - attn_k_input = paddle.normal(mean=0.0, std=0.02, - shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype) - attn_v_input = 
paddle.normal(mean=0.0, std=0.02, - shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype) - - q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32') - kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,), - dtype='int32') if attn_type == 'cross' else q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool') - - grad_out = paddle.normal(mean=0.0, std=0.02, - shape=(bs, q_seqlen, num_heads, head_size)).astype('float32') + attn_q_input = paddle.normal( + mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size) + ).astype(math_dtype) + attn_k_input = paddle.normal( + mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size) + ).astype(math_dtype) + attn_v_input = paddle.normal( + mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size) + ).astype(math_dtype) + + q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype="int32") + kv_actual_seqlen = ( + paddle.randint(low=20, high=kv_seqlen, shape=(bs,), dtype="int32") + if attn_type == "cross" + else q_actual_seqlen + ) + attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") + + grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)).astype( + "float32" + ) for i in range(0, bs): - grad_out[i, q_actual_seqlen[i]:, :, :] = 0 + grad_out[i, q_actual_seqlen[i] :, :, :] = 0 grad_out = grad_out.astype(math_dtype) for i in range(0, bs): - attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False + attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False head_size = hidden_size // num_heads - layer_te = te.DotProductAttention(num_heads, - head_size, - attention_dropout=0.0, - attn_mask_type=mask_type, - attention_type=attn_type, - backend='transformer_engine') - layer_pd = te.DotProductAttention(num_heads, - head_size, - attention_dropout=0.0, - attn_mask_type=mask_type, - attention_type=attn_type, - backend='paddle') + layer_te = te.DotProductAttention( + num_heads, + head_size, + attention_dropout=0.0, + attn_mask_type=mask_type, + attention_type=attn_type, + backend="transformer_engine", + ) + layer_pd = te.DotProductAttention( + num_heads, + head_size, + attention_dropout=0.0, + attn_mask_type=mask_type, + attention_type=attn_type, + backend="paddle", + ) def calc_attn_output_and_grad(layer, q, k, v, mask, dout): _q = paddle.to_tensor(q, stop_gradient=False) @@ -892,23 +953,29 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, out.backward(dout) return out, _q.grad, _k.grad, _v.grad - out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(layer_te, attn_q_input, attn_k_input, - attn_v_input, attn_mask, grad_out) + out, q_grad, k_grad, v_grad = calc_attn_output_and_grad( + layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out + ) out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad( - layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out) + layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out + ) valid_out_ref = paddle.full_like(out_ref, 0) for i in range(0, bs): - valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :] + valid_out_ref[i, 0 : q_actual_seqlen[i], :, :] = out_ref[i, 0 : q_actual_seqlen[i], :, :] valid_q_grad_ref = paddle.full_like(q_grad_ref, 0) valid_k_grad_ref = paddle.full_like(k_grad_ref, 0) valid_v_grad_ref = paddle.full_like(v_grad_ref, 0) for i in range(0, bs): - valid_q_grad_ref[i, 
0:q_actual_seqlen[i], :, :] = q_grad_ref[i, 0:q_actual_seqlen[i], :, :] - valid_k_grad_ref[i, 0:kv_actual_seqlen[i], :, :] = k_grad_ref[i, - 0:kv_actual_seqlen[i], :, :] - valid_v_grad_ref[i, 0:kv_actual_seqlen[i], :, :] = v_grad_ref[i, - 0:kv_actual_seqlen[i], :, :] + valid_q_grad_ref[i, 0 : q_actual_seqlen[i], :, :] = q_grad_ref[ + i, 0 : q_actual_seqlen[i], :, : + ] + valid_k_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = k_grad_ref[ + i, 0 : kv_actual_seqlen[i], :, : + ] + valid_v_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = v_grad_ref[ + i, 0 : kv_actual_seqlen[i], :, : + ] assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol) assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol) @@ -916,21 +983,34 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize('bs', [1, 2]) -@pytest.mark.parametrize('num_gqa_groups', [1, 2, 4]) -@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[256, 4, 1024]]) -@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[1024, 1024]]) -@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]]) -@pytest.mark.parametrize('no_wgrad', [True, False]) -@pytest.mark.parametrize('mask_type', ['causal', 'padding']) -@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16']) -@pytest.mark.parametrize('output_layernorm', [True, False]) -@pytest.mark.parametrize('return_layernorm_output', [True, False]) -@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm']) -def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size, - has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type, - math_dtype, output_layernorm, return_layernorm_output, - normalization): +@pytest.mark.parametrize("bs", [1, 2]) +@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4]) +@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]]) +@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]]) +@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]]) +@pytest.mark.parametrize("no_wgrad", [True, False]) +@pytest.mark.parametrize("mask_type", ["causal", "padding"]) +@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) +@pytest.mark.parametrize("output_layernorm", [True, False]) +@pytest.mark.parametrize("return_layernorm_output", [True, False]) +@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) +def test_transformer_encoder_layer( + bs, + hidden_size, + num_heads, + num_gqa_groups, + ffn_hidden_size, + has_bias, + no_dbias, + no_wgrad, + q_seqlen, + kv_seqlen, + mask_type, + math_dtype, + output_layernorm, + return_layernorm_output, + normalization, +): """ Test Transformer Encoder Layer """ @@ -938,68 +1018,73 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f rtol = 5e-2 atol = 5e-2 eps = 1e-3 - has_ln_bias = normalization == 'LayerNorm' + has_ln_bias = normalization == "LayerNorm" # Skip if cuDNN fused attention is not supported if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_gqa_groups, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=hidden_size // num_heads, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type=mask_type, + num_heads=num_heads, + num_gqa_groups=num_gqa_groups, + q_seqlen=q_seqlen, + kv_seqlen=kv_seqlen, + head_size=hidden_size // num_heads, + 
dtype=math_dtype, + dropout=0.0, + qkv_layout="bshd_bshd_bshd", + bias_type="no_bias", + mask_type=mask_type, ): pytest.skip("cuDNN fused attention is not supported") encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) - q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen + q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen kv_actual_seqlen = q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool') + attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - grad_out = paddle.normal(mean=0.0, std=0.02, - shape=(bs, q_seqlen, hidden_size)).astype('float32') + grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype( + "float32" + ) for i in range(0, bs): - grad_out[i, q_actual_seqlen[i]:, :] = 0 + grad_out[i, q_actual_seqlen[i] :, :] = 0 grad_out = grad_out.astype(math_dtype) for i in range(0, bs): - attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False - - layer_te = te.TransformerLayer(hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type='encoder', - normalization=normalization, - backend='transformer_engine') - layer_pd = te.TransformerLayer(hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type='encoder', - normalization=normalization, - backend='paddle') + attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False + + layer_te = te.TransformerLayer( + hidden_size, + ffn_hidden_size, + num_heads, + num_gqa_groups=num_gqa_groups, + layernorm_epsilon=eps, + hidden_dropout=0.0, + attention_dropout=0.0, + weight_attr=None, + bias_attr=None if has_bias else False, + self_attn_mask_type=mask_type, + apply_residual_connection_post_layernorm=return_layernorm_output, + output_layernorm=output_layernorm, + layer_type="encoder", + normalization=normalization, + backend="transformer_engine", + ) + layer_pd = te.TransformerLayer( + hidden_size, + ffn_hidden_size, + num_heads, + num_gqa_groups=num_gqa_groups, + layernorm_epsilon=eps, + hidden_dropout=0.0, + attention_dropout=0.0, + weight_attr=None, + bias_attr=None if has_bias else False, + self_attn_mask_type=mask_type, + apply_residual_connection_post_layernorm=return_layernorm_output, + output_layernorm=output_layernorm, + layer_type="encoder", + normalization=normalization, + backend="paddle", + ) # MultiHeadAttention params if output_layernorm: @@ -1012,21 +1097,25 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f layer_te.self_attention.qkv.bias.stop_gradient = no_dbias else: layer_pd.self_attention.layernorm_qkv.ln_weight.copy_( - layer_te.self_attention.layernorm_qkv.ln_weight, True) + layer_te.self_attention.layernorm_qkv.ln_weight, True + ) layer_pd.self_attention.layernorm_qkv.weight.copy_( - layer_te.self_attention.layernorm_qkv.weight.T, True) + layer_te.self_attention.layernorm_qkv.weight.T, True + ) 
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad if has_ln_bias: layer_pd.self_attention.layernorm_qkv.ln_bias.copy_( - layer_te.self_attention.layernorm_qkv.ln_bias, True) + layer_te.self_attention.layernorm_qkv.ln_bias, True + ) layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias if has_bias: layer_pd.self_attention.layernorm_qkv.bias.copy_( - layer_te.self_attention.layernorm_qkv.bias, True) + layer_te.self_attention.layernorm_qkv.bias, True + ) layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias @@ -1074,52 +1163,75 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f out.backward(dout) return out, _encoder_input.grad - out_ref, grad_input_ref = calc_transformer_output_and_grad(layer_pd, encoder_input, attn_mask, - grad_out) + out_ref, grad_input_ref = calc_transformer_output_and_grad( + layer_pd, encoder_input, attn_mask, grad_out + ) out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out) assert_allclose(out, out_ref, rtol=rtol, atol=atol) assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) if not no_wgrad: if output_layernorm: - assert_allclose(layer_te.self_attention.qkv.weight.grad, - layer_pd.self_attention.qkv.weight.grad.T, - rtol=rtol, - atol=atol) + assert_allclose( + layer_te.self_attention.qkv.weight.grad, + layer_pd.self_attention.qkv.weight.grad.T, + rtol=rtol, + atol=atol, + ) else: - assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad, - layer_pd.self_attention.layernorm_qkv.weight.grad.T, - rtol=rtol, - atol=atol) + assert_allclose( + layer_te.self_attention.layernorm_qkv.weight.grad, + layer_pd.self_attention.layernorm_qkv.weight.grad.T, + rtol=rtol, + atol=atol, + ) if not no_dbias: if output_layernorm: - assert_allclose(layer_te.self_attention.qkv.bias.grad, - layer_pd.self_attention.qkv.bias.grad, - rtol=0.01, - atol=0.5) + assert_allclose( + layer_te.self_attention.qkv.bias.grad, + layer_pd.self_attention.qkv.bias.grad, + rtol=0.01, + atol=0.5, + ) else: - assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad, - layer_pd.self_attention.layernorm_qkv.bias.grad, - rtol=0.01, - atol=0.5) - - -@pytest.mark.parametrize('bs', [1, 2]) -@pytest.mark.parametrize('num_gqa_groups', [1, 2, 4]) -@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[256, 4, 1024]]) -@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[1024, 1024]]) -@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]]) -@pytest.mark.parametrize('no_wgrad', [True, False]) -@pytest.mark.parametrize('mask_type', ['causal', 'padding']) -@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16']) -@pytest.mark.parametrize('output_layernorm', [True, False]) -@pytest.mark.parametrize('return_layernorm_output', [True, False]) -@pytest.mark.parametrize('recompute_core_attention', [True, False]) -@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm']) -def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size, - has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type, - math_dtype, output_layernorm, 
return_layernorm_output, - recompute_core_attention, normalization): + assert_allclose( + layer_te.self_attention.layernorm_qkv.bias.grad, + layer_pd.self_attention.layernorm_qkv.bias.grad, + rtol=0.01, + atol=0.5, + ) + + +@pytest.mark.parametrize("bs", [1, 2]) +@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4]) +@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]]) +@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]]) +@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]]) +@pytest.mark.parametrize("no_wgrad", [True, False]) +@pytest.mark.parametrize("mask_type", ["causal", "padding"]) +@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) +@pytest.mark.parametrize("output_layernorm", [True, False]) +@pytest.mark.parametrize("return_layernorm_output", [True, False]) +@pytest.mark.parametrize("recompute_core_attention", [True, False]) +@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) +def test_transformer_decoder_layer( + bs, + hidden_size, + num_heads, + num_gqa_groups, + ffn_hidden_size, + has_bias, + no_dbias, + no_wgrad, + q_seqlen, + kv_seqlen, + mask_type, + math_dtype, + output_layernorm, + return_layernorm_output, + recompute_core_attention, + normalization, +): """ Test Transformer Decoder Layer """ @@ -1127,34 +1239,37 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f rtol = 5e-2 atol = 6e-2 eps = 1e-3 - has_ln_bias = normalization == 'LayerNorm' + has_ln_bias = normalization == "LayerNorm" # Skip if cuDNN fused attention is not supported if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_gqa_groups, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=hidden_size // num_heads, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type=mask_type, + num_heads=num_heads, + num_gqa_groups=num_gqa_groups, + q_seqlen=q_seqlen, + kv_seqlen=kv_seqlen, + head_size=hidden_size // num_heads, + dtype=math_dtype, + dropout=0.0, + qkv_layout="bshd_bshd_bshd", + bias_type="no_bias", + mask_type=mask_type, ): pytest.skip("cuDNN fused attention is not supported") - encoder_input = paddle.normal(mean=0.0, std=0.1, - shape=(bs, q_seqlen, hidden_size)).astype(math_dtype) - encoder_output = paddle.normal(mean=0.0, std=0.1, - shape=(bs, kv_seqlen, hidden_size)).astype(math_dtype) + encoder_input = paddle.normal(mean=0.0, std=0.1, shape=(bs, q_seqlen, hidden_size)).astype( + math_dtype + ) + encoder_output = paddle.normal(mean=0.0, std=0.1, shape=(bs, kv_seqlen, hidden_size)).astype( + math_dtype + ) - q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen + q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen kv_actual_seqlen = q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool') + attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - grad_out = paddle.normal(mean=0.0, std=0.01, - shape=(bs, q_seqlen, hidden_size)).astype('float32') + grad_out = paddle.normal(mean=0.0, std=0.01, shape=(bs, q_seqlen, hidden_size)).astype( + "float32" + ) # rounding to avoid numerical issues encoder_input = paddle.round(encoder_input * 1000) / 1000 @@ -1162,42 +1277,46 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f grad_out = paddle.round(grad_out * 1000) / 1000 for i in range(0, bs): - grad_out[i, q_actual_seqlen[i]:, :] = 0 + grad_out[i, q_actual_seqlen[i] :, :] = 0 
grad_out = grad_out.astype(math_dtype) for i in range(0, bs): - attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False - - layer_te = te.TransformerLayer(hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type='decoder', - normalization=normalization, - backend='transformer_engine') - layer_pd = te.TransformerLayer(hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type='decoder', - normalization=normalization, - backend='paddle') + attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False + + layer_te = te.TransformerLayer( + hidden_size, + ffn_hidden_size, + num_heads, + num_gqa_groups=num_gqa_groups, + layernorm_epsilon=eps, + hidden_dropout=0.0, + attention_dropout=0.0, + weight_attr=None, + bias_attr=None if has_bias else False, + self_attn_mask_type=mask_type, + apply_residual_connection_post_layernorm=return_layernorm_output, + output_layernorm=output_layernorm, + layer_type="decoder", + normalization=normalization, + backend="transformer_engine", + ) + layer_pd = te.TransformerLayer( + hidden_size, + ffn_hidden_size, + num_heads, + num_gqa_groups=num_gqa_groups, + layernorm_epsilon=eps, + hidden_dropout=0.0, + attention_dropout=0.0, + weight_attr=None, + bias_attr=None if has_bias else False, + self_attn_mask_type=mask_type, + apply_residual_connection_post_layernorm=return_layernorm_output, + output_layernorm=output_layernorm, + layer_type="decoder", + normalization=normalization, + backend="paddle", + ) # MultiHeadAttention params - self attn if output_layernorm: @@ -1210,21 +1329,25 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f layer_te.self_attention.qkv.bias.stop_gradient = no_dbias else: layer_pd.self_attention.layernorm_qkv.ln_weight.copy_( - layer_te.self_attention.layernorm_qkv.ln_weight, True) + layer_te.self_attention.layernorm_qkv.ln_weight, True + ) layer_pd.self_attention.layernorm_qkv.weight.copy_( - layer_te.self_attention.layernorm_qkv.weight.T, True) + layer_te.self_attention.layernorm_qkv.weight.T, True + ) layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad if has_ln_bias: layer_pd.self_attention.layernorm_qkv.ln_bias.copy_( - layer_te.self_attention.layernorm_qkv.ln_bias, True) + layer_te.self_attention.layernorm_qkv.ln_bias, True + ) layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias if has_bias: layer_pd.self_attention.layernorm_qkv.bias.copy_( - layer_te.self_attention.layernorm_qkv.bias, True) + layer_te.self_attention.layernorm_qkv.bias, True + ) layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias layer_te.self_attention.layernorm_qkv.bias.stop_gradient = 
no_dbias @@ -1238,26 +1361,31 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f # MultiHeadAttention params - cross attn layer_pd.inter_attention.layernorm_query.ln_weight.copy_( - layer_te.inter_attention.layernorm_query.ln_weight, True) + layer_te.inter_attention.layernorm_query.ln_weight, True + ) layer_pd.inter_attention.layernorm_query.weight.copy_( - layer_te.inter_attention.layernorm_query.weight.T, True) + layer_te.inter_attention.layernorm_query.weight.T, True + ) layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad if has_ln_bias: layer_pd.inter_attention.layernorm_query.ln_bias.copy_( - layer_te.inter_attention.layernorm_query.ln_bias, True) + layer_te.inter_attention.layernorm_query.ln_bias, True + ) layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias if has_bias: layer_pd.inter_attention.layernorm_query.bias.copy_( - layer_te.inter_attention.layernorm_query.bias, True) + layer_te.inter_attention.layernorm_query.bias, True + ) layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias - layer_pd.inter_attention.key_value.weight.copy_(layer_te.inter_attention.key_value.weight.T, - True) + layer_pd.inter_attention.key_value.weight.copy_( + layer_te.inter_attention.key_value.weight.T, True + ) layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True) @@ -1301,25 +1429,30 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f layer_te.layernorm.weight.stop_gradient = no_wgrad layer_te.layernorm.bias.stop_gradient = no_dbias - def calc_transformer_output_and_grad(layer, - encoder_input, - mask, - encoder_output, - enc_dec_attn_mask, - dout, - recompute_core_attention=False): + def calc_transformer_output_and_grad( + layer, + encoder_input, + mask, + encoder_output, + enc_dec_attn_mask, + dout, + recompute_core_attention=False, + ): _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False) _encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False) - out = layer(_encoder_input, - mask, - _encoder_output, - enc_dec_attn_mask, - recompute_core_attention=recompute_core_attention) + out = layer( + _encoder_input, + mask, + _encoder_output, + enc_dec_attn_mask, + recompute_core_attention=recompute_core_attention, + ) out.backward(dout) return out, _encoder_input.grad, _encoder_output.grad out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad( - layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out) + layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out + ) out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad( layer_te, encoder_input, @@ -1327,52 +1460,74 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f encoder_output, attn_mask, grad_out, - recompute_core_attention=recompute_core_attention) + recompute_core_attention=recompute_core_attention, + ) assert_allclose(out, out_ref, rtol=rtol, 
atol=atol) assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol) assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol) if not no_wgrad: if output_layernorm: - assert_allclose(layer_te.self_attention.qkv.weight.grad, - layer_pd.self_attention.qkv.weight.grad.T, - rtol=rtol, - atol=atol) + assert_allclose( + layer_te.self_attention.qkv.weight.grad, + layer_pd.self_attention.qkv.weight.grad.T, + rtol=rtol, + atol=atol, + ) else: - assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad, - layer_pd.self_attention.layernorm_qkv.weight.grad.T, - rtol=rtol, - atol=atol) - assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad, - layer_pd.inter_attention.layernorm_query.weight.grad.T, - rtol=rtol, - atol=atol) + assert_allclose( + layer_te.self_attention.layernorm_qkv.weight.grad, + layer_pd.self_attention.layernorm_qkv.weight.grad.T, + rtol=rtol, + atol=atol, + ) + assert_allclose( + layer_te.inter_attention.layernorm_query.weight.grad, + layer_pd.inter_attention.layernorm_query.weight.grad.T, + rtol=rtol, + atol=atol, + ) if not no_dbias: if output_layernorm: - assert_allclose(layer_te.self_attention.qkv.bias.grad, - layer_pd.self_attention.qkv.bias.grad, - rtol=0.5, - atol=0.6) + assert_allclose( + layer_te.self_attention.qkv.bias.grad, + layer_pd.self_attention.qkv.bias.grad, + rtol=0.5, + atol=0.6, + ) else: - assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad, - layer_pd.self_attention.layernorm_qkv.bias.grad, - rtol=0.01, - atol=0.5) - assert_allclose(layer_te.inter_attention.layernorm_query.bias.grad, - layer_pd.inter_attention.layernorm_query.bias.grad, - rtol=rtol, - atol=atol) + assert_allclose( + layer_te.self_attention.layernorm_qkv.bias.grad, + layer_pd.self_attention.layernorm_qkv.bias.grad, + rtol=0.01, + atol=0.5, + ) + assert_allclose( + layer_te.inter_attention.layernorm_query.bias.grad, + layer_pd.inter_attention.layernorm_query.bias.grad, + rtol=rtol, + atol=atol, + ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) -@pytest.mark.parametrize('bs', [8]) -@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]]) -@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128]]) -@pytest.mark.parametrize('mask_type', ['causal']) -@pytest.mark.parametrize('math_dtype', ['bfloat16']) -@pytest.mark.parametrize('num_microbatch', [8]) -def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hidden_size, q_seqlen, - kv_seqlen, mask_type, math_dtype, num_microbatch): +@pytest.mark.parametrize("bs", [8]) +@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[1024, 16, 4096]]) +@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[128, 128]]) +@pytest.mark.parametrize("mask_type", ["causal"]) +@pytest.mark.parametrize("math_dtype", ["bfloat16"]) +@pytest.mark.parametrize("num_microbatch", [8]) +def test_transformer_encoder_layer_microbatch( + bs, + hidden_size, + num_heads, + ffn_hidden_size, + q_seqlen, + kv_seqlen, + mask_type, + math_dtype, + num_microbatch, +): """ Test Transformer Encoder Layer with FP8 weight caching """ @@ -1383,48 +1538,56 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi # Skip if cuDNN fused attention is not supported if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_heads, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=hidden_size // num_heads, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - 
mask_type=mask_type, + num_heads=num_heads, + num_gqa_groups=num_heads, + q_seqlen=q_seqlen, + kv_seqlen=kv_seqlen, + head_size=hidden_size // num_heads, + dtype=math_dtype, + dropout=0.0, + qkv_layout="bs3hd", + bias_type="no_bias", + mask_type=mask_type, ): pytest.skip("cuDNN fused attention is not supported") - layer_cached = te.TransformerLayer(hidden_size, - ffn_hidden_size, - num_heads, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None, - self_attn_mask_type=mask_type, - layer_type='encoder') - layer_normal = te.TransformerLayer(hidden_size, - ffn_hidden_size, - num_heads, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None, - self_attn_mask_type=mask_type, - layer_type='encoder') + layer_cached = te.TransformerLayer( + hidden_size, + ffn_hidden_size, + num_heads, + layernorm_epsilon=eps, + hidden_dropout=0.0, + attention_dropout=0.0, + weight_attr=None, + bias_attr=None, + self_attn_mask_type=mask_type, + layer_type="encoder", + ) + layer_normal = te.TransformerLayer( + hidden_size, + ffn_hidden_size, + num_heads, + layernorm_epsilon=eps, + hidden_dropout=0.0, + attention_dropout=0.0, + weight_attr=None, + bias_attr=None, + self_attn_mask_type=mask_type, + layer_type="encoder", + ) layer_normal.self_attention.layernorm_qkv.ln_weight.copy_( - layer_cached.self_attention.layernorm_qkv.ln_weight, True) + layer_cached.self_attention.layernorm_qkv.ln_weight, True + ) layer_normal.self_attention.layernorm_qkv.ln_bias.copy_( - layer_cached.self_attention.layernorm_qkv.ln_bias, True) + layer_cached.self_attention.layernorm_qkv.ln_bias, True + ) layer_normal.self_attention.layernorm_qkv.weight.copy_( - layer_cached.self_attention.layernorm_qkv.weight, True) + layer_cached.self_attention.layernorm_qkv.weight, True + ) layer_normal.self_attention.layernorm_qkv.bias.copy_( - layer_cached.self_attention.layernorm_qkv.bias, True) + layer_cached.self_attention.layernorm_qkv.bias, True + ) layer_normal.self_attention.proj.weight.copy_(layer_cached.self_attention.proj.weight, True) layer_normal.self_attention.proj.bias.copy_(layer_cached.self_attention.proj.bias, True) @@ -1442,18 +1605,19 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi def generate_input(): encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) - q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen + q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen kv_actual_seqlen = q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool') + attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - grad_out = paddle.normal(mean=0.0, std=0.02, - shape=(bs, q_seqlen, hidden_size)).astype('float32') + grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype( + "float32" + ) for i in range(0, bs): - grad_out[i, q_actual_seqlen[i]:, :] = 0 + grad_out[i, q_actual_seqlen[i] :, :] = 0 grad_out = grad_out.astype(math_dtype) for i in range(0, bs): - attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False + attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False return encoder_input, attn_mask, grad_out @@ -1477,7 +1641,9 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi out_ref.backward(grad_out) assert_allclose(out, out_ref, rtol=rtol, atol=atol) - 
assert_allclose(layer_cached.self_attention.layernorm_qkv.weight.grad, - layer_normal.self_attention.layernorm_qkv.weight.grad, - rtol=rtol, - atol=atol) + assert_allclose( + layer_cached.self_attention.layernorm_qkv.weight.grad, + layer_normal.self_attention.layernorm_qkv.weight.grad, + rtol=rtol, + atol=atol, + ) diff --git a/tests/paddle/test_master_grad.py b/tests/paddle/test_master_grad.py index f5f67cba4e9fb647ac85d15938b49fbc70fc829e..4e029cf8ddd9a9fea1a6b03a06ffd2709587d653 100644 --- a/tests/paddle/test_master_grad.py +++ b/tests/paddle/test_master_grad.py @@ -16,7 +16,7 @@ is_fp8_supported, reason = is_fp8_available() def create_optimizer(model, use_pure_bf16, use_main_grad): - '''Create optimizer''' + """Create optimizer""" if use_main_grad: assert use_pure_bf16 model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") @@ -32,7 +32,7 @@ def create_optimizer(model, use_pure_bf16, use_main_grad): class Net(paddle.nn.Layer): - '''Network use for main_grad testing''' + """Network use for main_grad testing""" def __init__(self, fuse_wgrad_accumulation): super().__init__() @@ -40,7 +40,7 @@ class Net(paddle.nn.Layer): 4096, 16384, 32, - layer_type='encoder', + layer_type="encoder", fuse_wgrad_accumulation=fuse_wgrad_accumulation, ) @@ -50,7 +50,7 @@ class Net(paddle.nn.Layer): def train(enable_master_grad, fuse_wgrad_accumulation=False): - '''Train function''' + """Train function""" paddle.seed(10) accumulate_steps = 4 @@ -64,7 +64,7 @@ def train(enable_master_grad, fuse_wgrad_accumulation=False): loss_list = [] for step_id in range(16): - inp = paddle.uniform([2, 1024, 4096], dtype='float32') + inp = paddle.uniform([2, 1024, 4096], dtype="float32") inp.stop_gradient = False with te.fp8_autocast(enabled=True): out = model(inp) @@ -82,8 +82,8 @@ def train(enable_master_grad, fuse_wgrad_accumulation=False): @pytest.mark.skipif(not is_fp8_supported, reason=reason) def test_master_grad(): - '''Test main_grad''' - paddle.set_default_dtype('float32') + """Test main_grad""" + paddle.set_default_dtype("float32") loss1 = train(enable_master_grad=False) loss2 = train(enable_master_grad=True) loss3 = train(enable_master_grad=True, fuse_wgrad_accumulation=True) diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py index bb6a73eaecf66a0ab1595b4c5506ca5b9cd563ce..b3b856077513d2a39cc05133e4f4049678e5f70f 100644 --- a/tests/paddle/test_operators.py +++ b/tests/paddle/test_operators.py @@ -56,8 +56,13 @@ from transformer_engine.paddle.fp8 import is_fp8_available from transformer_engine.paddle.constants import FP8FwdTensors from transformer_engine.common.recipe import DelayedScaling -GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024), - (16384, 1024, 1024)] +GEMM_CASES = [ + (256, 256, 512), + (32, 32, 32), + (16384, 1024, 2816), + (16384, 2816, 1024), + (16384, 1024, 1024), +] is_fp8_supported, reason = is_fp8_available() SELF_ATTN_CASES = [(2, 512, 12, 64)] @@ -74,13 +79,13 @@ def setup(): yield -@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) -@pytest.mark.parametrize('inplace', [True, False]) +@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("inplace", [True, False]) def test_quantize_dequantize(fp8_dtype, inplace): """ Test cast_to_fp8 and cast_from_fp8 """ - a = paddle.rand(shape=(32, 32), dtype='float32') + a = paddle.rand(shape=(32, 32), dtype="float32") # Init fp8_meta fp8_meta = create_fp8_meta() a_fp8 = 
paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None @@ -99,7 +104,7 @@ def copy_bits_from_float_to_uint16(f): """ Copy bits """ - return struct.unpack('<I', struct.pack('<f', f))[0] >> 16 + return struct.unpack("<I", struct.pack("<f", f))[0] >> 16 def convert_float_to_uint16(float_list): @@ -124,95 +129,106 @@ class TestTranspose: """ Test BF16 transpose """ - a = paddle.rand(shape=(16, 32), dtype='bfloat16') + a = paddle.rand(shape=(16, 32), dtype="bfloat16") a_transposed = transpose(a, otype=tex.DType.kBFloat16) assert_allclose(a_transposed, a.T) @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) def test_transpose_fp8(fp8_dtype): """ Test FP8 transpose """ min_val = -8 max_val = 8 - a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32') + a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32") fp8_meta = create_fp8_meta() a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype) - a_transposed = cast_from_fp8(a_fp8_transposed, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32) + a_transposed = cast_from_fp8( + a_fp8_transposed, + fp8_meta, + FP8FwdTensors.GEMM1_INPUT, + itype=fp8_dtype, + otype=tex.DType.kFloat32, + ) assert_allclose(a_transposed, a.T) @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - @pytest.mark.parametrize('inplace', [True, False]) + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) + @pytest.mark.parametrize("inplace", [True, False]) def test_cast_transpose(fp8_dtype, inplace): """ Test cast_transpose """ min_val = -8 max_val = 8 - a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32') + a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32") fp8_meta = create_fp8_meta() a_fp8_casted, a_fp8_transposed = None, None if inplace: a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8) a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8) - a_fp8_casted, a_fp8_transposed = cast_transpose(a, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - otype=fp8_dtype, - cast_out=a_fp8_casted, - transpose_out=a_fp8_transposed) - - a_transposed = cast_from_fp8(a_fp8_transposed, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32) - - a_casted = cast_from_fp8(a_fp8_casted, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32) + a_fp8_casted, a_fp8_transposed = cast_transpose( + a, + fp8_meta, + FP8FwdTensors.GEMM1_INPUT, + otype=fp8_dtype, + cast_out=a_fp8_casted, + transpose_out=a_fp8_transposed, + ) + + a_transposed = cast_from_fp8( + a_fp8_transposed, + fp8_meta, + FP8FwdTensors.GEMM1_INPUT, + itype=fp8_dtype, + otype=tex.DType.kFloat32, + ) + + a_casted = cast_from_fp8( + a_fp8_casted, + fp8_meta, + FP8FwdTensors.GEMM1_INPUT, + itype=fp8_dtype, + otype=tex.DType.kFloat32, + ) assert_allclose(a_casted, a) assert_allclose(a_transposed, a.T) @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) def
test_cast_transpose_bgrad(fp8_dtype): """ Test cast_transpose_bgrad """ min_val = -8 max_val = 8 - a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32') + a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32") fp8_meta = create_fp8_meta() - bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad(a, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - otype=fp8_dtype) - - a_transposed = cast_from_fp8(a_fp8_transposed, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32) - - a_casted = cast_from_fp8(a_fp8_casted, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32) + bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad( + a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype + ) + + a_transposed = cast_from_fp8( + a_fp8_transposed, + fp8_meta, + FP8FwdTensors.GEMM1_INPUT, + itype=fp8_dtype, + otype=tex.DType.kFloat32, + ) + + a_casted = cast_from_fp8( + a_fp8_casted, + fp8_meta, + FP8FwdTensors.GEMM1_INPUT, + itype=fp8_dtype, + otype=tex.DType.kFloat32, + ) assert_allclose(a_casted, a) assert_allclose(a_transposed, a.T) @@ -229,7 +245,7 @@ class TestActivation: """ Test BF16 GELU Forward """ - a = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1 + a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1 gelu_out = te_gelu(a, otype=tex.DType.kBFloat16) gelu_ref = paddle.nn.GELU()(a) @@ -237,21 +253,23 @@ class TestActivation: @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) def test_gelu_fp8(fp8_dtype): """ Test FP8 GELU Forward """ - a = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1 + a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 fp8_meta = create_fp8_meta() gelu_out_fp8 = gelu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - gelu_out = cast_from_fp8(gelu_out_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32) + gelu_out = cast_from_fp8( + gelu_out_fp8, + fp8_meta, + FP8FwdTensors.GEMM1_INPUT, + itype=fp8_dtype, + otype=tex.DType.kFloat32, + ) gelu_ref = paddle.nn.GELU()(a) @@ -259,36 +277,38 @@ class TestActivation: @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) def test_gelu_bwd_fp8(fp8_dtype): """ Test FP8 GELU Backward """ # y = GELU(x), calculate ref - x = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1 + x = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 x.stop_gradient = False y = paddle.nn.GELU()(x) - y_grad = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1 + y_grad = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 paddle.autograd.backward([y], [y_grad], True) # calculate fp8 fp8_meta = create_fp8_meta() - x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8(y_grad, - x, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - otype=fp8_dtype) - - x_grad = cast_from_fp8(x_grad_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32) - - x_grad_t = cast_from_fp8(x_grad_t_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32) + x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8( + y_grad, x, 
fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype + ) + + x_grad = cast_from_fp8( + x_grad_fp8, + fp8_meta, + FP8FwdTensors.GEMM1_INPUT, + itype=fp8_dtype, + otype=tex.DType.kFloat32, + ) + + x_grad_t = cast_from_fp8( + x_grad_t_fp8, + fp8_meta, + FP8FwdTensors.GEMM1_INPUT, + itype=fp8_dtype, + otype=tex.DType.kFloat32, + ) assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01) assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01) @@ -299,7 +319,7 @@ class TestActivation: """ Test BF16 SwiGLU Forward """ - a = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1 + a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1 swiglu_out = swiglu(a, otype=tex.DType.kBFloat16) swiglu_ref = swiglu_pd(a) @@ -307,21 +327,23 @@ class TestActivation: @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) def test_swiglu_fp8(fp8_dtype): """ Test FP8 SwiGLU Forward """ - a = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1 + a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 fp8_meta = create_fp8_meta() swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - swiglu_out = cast_from_fp8(swiglu_out_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32) + swiglu_out = cast_from_fp8( + swiglu_out_fp8, + fp8_meta, + FP8FwdTensors.GEMM1_INPUT, + itype=fp8_dtype, + otype=tex.DType.kFloat32, + ) swiglu_ref = swiglu_pd(a) @@ -333,10 +355,10 @@ class TestActivation: Test SwiGLU Backward """ # y = SwiGLU(x), calculate ref - x = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1 + x = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1 x.stop_gradient = False y = swiglu_pd(x) - y_grad = paddle.rand(shape=(16, 16), dtype='bfloat16') * 2 - 1 + y_grad = paddle.rand(shape=(16, 16), dtype="bfloat16") * 2 - 1 paddle.autograd.backward([y], [y_grad], True) # calculate fp8 x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16) @@ -350,17 +372,18 @@ class TestGemm: """ @staticmethod - @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 GEMM requires Ampere+ GPU") - @pytest.mark.parametrize('m,n,k', GEMM_CASES) + @pytest.mark.skipif( + paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU" + ) + @pytest.mark.parametrize("m,n,k", GEMM_CASES) def test_bf16(m, n, k): """ Test "TN" BF16 GEMM """ - a = paddle.rand(shape=(m, k), dtype='bfloat16') - b = paddle.rand(shape=(n, k), dtype='bfloat16') + a = paddle.rand(shape=(m, k), dtype="bfloat16") + b = paddle.rand(shape=(n, k), dtype="bfloat16") - workspace = paddle.zeros(shape=[33_554_432], dtype='uint8') + workspace = paddle.zeros(shape=[33_554_432], dtype="uint8") ref_out = paddle.matmul(a, b.T) # CublasLt inside tex.te_gemm assumes inputs are column major. @@ -368,37 +391,51 @@ class TestGemm: # transpose of X. # Here we perform "TN" GEMM in column major, i.e., b@a^T = C^T, # which is equivalent to a@b^T = C in row major. 
- actual_out, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, False, "TN", - None, None, False) + actual_out, _, _ = gemm( + b, a, paddle.bfloat16, workspace, False, None, False, False, "TN", None, None, False + ) assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5) @staticmethod - @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 GEMM requires Ampere+ GPU") - @pytest.mark.parametrize('m,n,k', GEMM_CASES) + @pytest.mark.skipif( + paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU" + ) + @pytest.mark.parametrize("m,n,k", GEMM_CASES) def test_bf16_inplace(m, n, k): """ Test "TN" BF16 GEMM, with accumulate=True """ min_val = -16 max_val = 16 - a = paddle.rand(shape=(m, k), dtype='bfloat16') - b = paddle.rand(shape=(n, k), dtype='bfloat16') - c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), 'bfloat16') - workspace = paddle.zeros(shape=[33_554_432], dtype='uint8') + a = paddle.rand(shape=(m, k), dtype="bfloat16") + b = paddle.rand(shape=(n, k), dtype="bfloat16") + c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), "bfloat16") + workspace = paddle.zeros(shape=[33_554_432], dtype="uint8") ref_out = c + paddle.matmul(a, b.T) actual_out = paddle.clone(c) - _, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, True, "TN", actual_out, - None, False) + _, _, _ = gemm( + b, + a, + paddle.bfloat16, + workspace, + False, + None, + False, + True, + "TN", + actual_out, + None, + False, + ) assert_allclose(actual_out, ref_out, rtol=5e-2, atol=5e-2) @staticmethod @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('m,n,k', GEMM_CASES) + @pytest.mark.parametrize("m,n,k", GEMM_CASES) def test_fp8_randint(m, n, k): """ Test "TN" FP8 GEMM @@ -409,17 +446,26 @@ class TestGemm: out_dtype = paddle.float32 fp8_meta = create_fp8_meta(num_gemms=1) - a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), 'float32') + a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), "float32") a_casted = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), 'float32') + b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), "float32") b_casted = cast_to_fp8(b, fp8_meta, FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype) - workspace = paddle.zeros(shape=[33_554_432], dtype='uint8') + workspace = paddle.zeros(shape=[33_554_432], dtype="uint8") ref_out = paddle.matmul(a, b.T) - actual_out, _ = fp8_gemm(b_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype, a_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_INPUT, - fp8_dtype, out_dtype, workspace) + actual_out, _ = fp8_gemm( + b_casted, + fp8_meta.scale_inv, + FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype, + a_casted, + fp8_meta.scale_inv, + FP8FwdTensors.GEMM1_INPUT, + fp8_dtype, + out_dtype, + workspace, + ) assert_allclose(actual_out, ref_out) @@ -434,14 +480,12 @@ class TestLayerNorm: """ Calculate reference using paddle layer_norm op """ - y = paddle.nn.functional.layer_norm(x=x, - normalized_shape=x.shape[1:], - weight=gamma, - bias=beta, - epsilon=eps) + y = paddle.nn.functional.layer_norm( + x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps + ) mean = paddle.mean(x, axis=-1) var = paddle.var(x, axis=-1) - inv_var = paddle.sqrt(1. 
/ var) + inv_var = paddle.sqrt(1.0 / var) return y, mean, inv_var @staticmethod @@ -453,11 +497,9 @@ class TestLayerNorm: gamma.stop_gradient = False beta.stop_gradient = False - y = paddle.nn.functional.layer_norm(x=x, - normalized_shape=x.shape[1:], - weight=gamma, - bias=beta, - epsilon=eps) + y = paddle.nn.functional.layer_norm( + x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps + ) paddle.autograd.backward([y], [dy], True) @@ -469,9 +511,9 @@ class TestLayerNorm: """ N, H = (16, 32) eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype='bfloat16') - gamma = paddle.uniform(shape=(H,), dtype='bfloat16') - beta = paddle.uniform(shape=(H,), dtype='bfloat16') + x = paddle.uniform(shape=(N, H), dtype="bfloat16") + gamma = paddle.uniform(shape=(H,), dtype="bfloat16") + beta = paddle.uniform(shape=(H,), dtype="bfloat16") y, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16) @@ -490,9 +532,9 @@ class TestLayerNorm: N, H = (16, 32) eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype='float32') - gamma = paddle.uniform(shape=(H,), dtype='float32') - beta = paddle.uniform(shape=(H,), dtype='float32') + x = paddle.uniform(shape=(N, H), dtype="float32") + gamma = paddle.uniform(shape=(H,), dtype="float32") + beta = paddle.uniform(shape=(H,), dtype="float32") fp8_tensor = FP8FwdTensors.GEMM1_INPUT fp8_meta = create_fp8_meta() @@ -513,10 +555,10 @@ class TestLayerNorm: """ N, H = (16, 32) eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype='bfloat16') - dy = paddle.uniform(shape=(N, H), dtype='bfloat16') - gamma = paddle.uniform(shape=(H,), dtype='bfloat16') - beta = paddle.uniform(shape=(H,), dtype='bfloat16') + x = paddle.uniform(shape=(N, H), dtype="bfloat16") + dy = paddle.uniform(shape=(N, H), dtype="bfloat16") + gamma = paddle.uniform(shape=(H,), dtype="bfloat16") + beta = paddle.uniform(shape=(H,), dtype="bfloat16") dx_ref, dgamma_ref, dbeta_ref = self.calc_bwd_ref(x, eps, gamma, beta, dy) @@ -563,8 +605,8 @@ class TestRMSNorm: """ N, H = (16, 32) eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype='bfloat16') - gamma = paddle.uniform(shape=(H,), dtype='bfloat16') + x = paddle.uniform(shape=(N, H), dtype="bfloat16") + gamma = paddle.uniform(shape=(H,), dtype="bfloat16") y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16) @@ -581,8 +623,8 @@ class TestRMSNorm: N, H = (16, 32) eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype='float32') - gamma = paddle.uniform(shape=(H,), dtype='float32') + x = paddle.uniform(shape=(N, H), dtype="float32") + gamma = paddle.uniform(shape=(H,), dtype="float32") fp8_tensor = FP8FwdTensors.GEMM1_INPUT fp8_meta = create_fp8_meta() @@ -602,9 +644,9 @@ class TestRMSNorm: """ N, H = (16, 32) eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype='bfloat16') - dy = paddle.uniform(shape=(N, H), dtype='bfloat16') - gamma = paddle.uniform(shape=(H,), dtype='bfloat16') + x = paddle.uniform(shape=(N, H), dtype="bfloat16") + dy = paddle.uniform(shape=(N, H), dtype="bfloat16") + gamma = paddle.uniform(shape=(H,), dtype="bfloat16") dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy) @@ -620,7 +662,7 @@ class TestFusedAttn: Test fused attention operators """ - def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode='self_attn', is_causal_masking=False): + def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode="self_attn", is_causal_masking=False): """ set test input """ @@ -682,10 +724,10 @@ class TestFusedAttn: assert attn_mode == "self_attn", "only support causal masking for self attention" for i in range(0, self.batch_size): 
for j in range(self.q_actual_seqlen[i]): - self.attn_mask[i, :, j, :j + 1] = 0 + self.attn_mask[i, :, j, : j + 1] = 0 else: for i in range(0, self.batch_size): - self.attn_mask[i, :, :self.q_actual_seqlen[i], :self.kv_actual_seqlen[i]] = 0 + self.attn_mask[i, :, : self.q_actual_seqlen[i], : self.kv_actual_seqlen[i]] = 0 dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size)) self.dout = paddle.to_tensor(dout, dtype=self.dtype) @@ -696,9 +738,9 @@ class TestFusedAttn: k_tensor = paddle.to_tensor(self.kv, stop_gradient=False) v_tensor = paddle.to_tensor(self.kv, stop_gradient=False) - q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] - k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] - v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] + q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] + k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] + v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] qk_out = paddle.matmul( x=q_out * self.scaling_factor, @@ -707,10 +749,10 @@ class TestFusedAttn: transpose_y=True, ) - attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast('bool') + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast("bool") attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype) attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out) - attn_mask_out = paddle.cast(attn_mask_out, 'float32') + attn_mask_out = paddle.cast(attn_mask_out, "float32") softmax_out = F.softmax(attn_mask_out) softmax_out = paddle.cast(softmax_out, self.dtype) @@ -725,7 +767,7 @@ class TestFusedAttn: else: qkv_out = paddle.matmul(softmax_out, v_out) - out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3]) # [b, h, s, d] -> [b, s, h, d] + out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3]) # [b, h, s, d] -> [b, s, h, d] paddle.autograd.backward( [out], @@ -738,17 +780,17 @@ class TestFusedAttn: paddle.disable_static(place=paddle.CUDAPlace(0)) if self.attn_mode == "self_attn": - qkv = np.stack([self.q, self.kv, self.kv], axis=2) # [b, s, 3, h, d] + qkv = np.stack([self.q, self.kv, self.kv], axis=2) # [b, s, 3, h, d] qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False) else: q_tensor = paddle.to_tensor(self.q, stop_gradient=False) - kv = np.stack([self.kv, self.kv], axis=2) # [b, s, 2, h, d] + kv = np.stack([self.kv, self.kv], axis=2) # [b, s, 2, h, d] kv_tensor = paddle.to_tensor(kv, stop_gradient=False) q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True) kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True) - qkv_layout = ("bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd") + qkv_layout = "bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd" fused_attention_backend = get_fused_attention_backend( num_heads=self.num_heads, num_gqa_groups=self.num_heads, @@ -764,7 +806,7 @@ class TestFusedAttn: qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16 out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None - if self.attn_mode == 'self_attn': + if self.attn_mode == "self_attn": out, softmax_aux_tensor, rng_state = fused_attn_fwd_qkvpacked( qkv_tensor, q_cu_seqlen_tensor, @@ -776,7 +818,8 @@ class TestFusedAttn: attn_scale=self.scaling_factor, dropout=self.dropout_prob, set_zero=False, - 
attn_mask_type="causal" if self.is_causal_masking else "padding") + attn_mask_type="causal" if self.is_causal_masking else "padding", + ) dqkv, _ = fused_attn_bwd_qkvpacked( qkv_tensor, q_cu_seqlen_tensor, @@ -790,11 +833,12 @@ class TestFusedAttn: attn_scale=self.scaling_factor, dropout=self.dropout_prob, set_zero=False, - attn_mask_type="causal" if self.is_causal_masking else "padding") + attn_mask_type="causal" if self.is_causal_masking else "padding", + ) q_grad = dqkv[:, :, 0, :, :] k_grad = dqkv[:, :, 1, :, :] v_grad = dqkv[:, :, 2, :, :] - else: # attn_mode == 'cross_attn' + else: # attn_mode == 'cross_attn' out, softmax_aux_tensor, rng_state = fused_attn_fwd_kvpacked( q_tensor, kv_tensor, @@ -808,22 +852,25 @@ class TestFusedAttn: Bias=None, attn_scale=self.scaling_factor, dropout=self.dropout_prob, - set_zero=False) - dq, dkv, _ = fused_attn_bwd_kvpacked(q_tensor, - kv_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - rng_state, - out, - self.dout, - softmax_aux_tensor, - fused_attention_backend=fused_attention_backend, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False) + set_zero=False, + ) + dq, dkv, _ = fused_attn_bwd_kvpacked( + q_tensor, + kv_tensor, + q_cu_seqlen_tensor, + kv_cu_seqlen_tensor, + rng_state, + out, + self.dout, + softmax_aux_tensor, + fused_attention_backend=fused_attention_backend, + max_seqlen_q=self.q_seqlen, + max_seqlen_kv=self.kv_seqlen, + qkv_dtype=qkv_dtype, + attn_scale=self.scaling_factor, + dropout=self.dropout_prob, + set_zero=False, + ) q_grad = dq k_grad = dkv[:, :, 0, :, :] v_grad = dkv[:, :, 1, :, :] @@ -871,7 +918,8 @@ class TestFusedAttn: dropout=self.dropout_prob, set_zero=False, qkv_layout=qkv_layout, - attn_mask_type="causal" if self.is_causal_masking else "padding") + attn_mask_type="causal" if self.is_causal_masking else "padding", + ) dq, dk, dv, _ = fused_attn_bwd( q_tensor, k_tensor, @@ -890,28 +938,29 @@ class TestFusedAttn: dropout=self.dropout_prob, set_zero=False, qkv_layout=qkv_layout, - attn_mask_type="causal" if self.is_causal_masking else "padding") + attn_mask_type="causal" if self.is_causal_masking else "padding", + ) return out, dq, dk, dv - @pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES) - @pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) - @pytest.mark.parametrize('is_causal_masking', [True, False]) + @pytest.mark.parametrize("b, s, h, d", SELF_ATTN_CASES) + @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) + @pytest.mark.parametrize("is_causal_masking", [True, False]) def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): """ test self attention forward + backward """ if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s, - kv_seqlen=s, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - mask_type="causal" if is_causal_masking else "padding", + num_heads=h, + num_gqa_groups=h, + q_seqlen=s, + kv_seqlen=s, + head_size=d, + dtype=dtype, + dropout=0.0, + qkv_layout="bs3hd", + bias_type="no_bias", + mask_type="causal" if is_causal_masking else "padding", ): pytest.skip("cuDNN fused attention is not supported") self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) @@ -922,23 +971,23 @@ class TestFusedAttn: assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - @pytest.mark.parametrize('b, s_q, s_kv, h, d', 
CROSS_ATTN_CASES) - @pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) + @pytest.mark.parametrize("b, s_q, s_kv, h, d", CROSS_ATTN_CASES) + @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype): """ test cross attention forward + backward """ if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s_q, - kv_seqlen=s_kv, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bshd_bs2hd", - bias_type="no_bias", - mask_type="padding", + num_heads=h, + num_gqa_groups=h, + q_seqlen=s_q, + kv_seqlen=s_kv, + head_size=d, + dtype=dtype, + dropout=0.0, + qkv_layout="bshd_bs2hd", + bias_type="no_bias", + mask_type="padding", ): pytest.skip("cuDNN fused attention is not supported") self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn") @@ -949,24 +998,24 @@ class TestFusedAttn: assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - @pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES) - @pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) - @pytest.mark.parametrize('is_causal_masking', [True]) + @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES) + @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) + @pytest.mark.parametrize("is_causal_masking", [True]) def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): """ test flash attention forward + backward """ if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s, - kv_seqlen=s, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - mask_type="causal" if is_causal_masking else "padding", + num_heads=h, + num_gqa_groups=h, + q_seqlen=s, + kv_seqlen=s, + head_size=d, + dtype=dtype, + dropout=0.0, + qkv_layout="bs3hd", + bias_type="no_bias", + mask_type="causal" if is_causal_masking else "padding", ): pytest.skip("cuDNN fused attention is not supported") self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) @@ -977,25 +1026,26 @@ class TestFusedAttn: assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - @pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES) - @pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) - @pytest.mark.parametrize('is_causal_masking', [False, True]) - def test_fused_attn_with_separate_qkv_forward_backward(self, b, s, h, d, dtype, - is_causal_masking): + @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES) + @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) + @pytest.mark.parametrize("is_causal_masking", [False, True]) + def test_fused_attn_with_separate_qkv_forward_backward( + self, b, s, h, d, dtype, is_causal_masking + ): """ test flash attention forward + backward with separate qkv inputs """ if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s, - kv_seqlen=s, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type="causal" if is_causal_masking else "padding", + num_heads=h, + num_gqa_groups=h, + q_seqlen=s, + kv_seqlen=s, + head_size=d, + dtype=dtype, + dropout=0.0, + qkv_layout="bshd_bshd_bshd", + bias_type="no_bias", + mask_type="causal" if is_causal_masking else "padding", ): pytest.skip("cuDNN fused attention is not supported") self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) @@ -1013,7 +1063,7 @@ class TestSoftmax: """ @staticmethod - 
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) + @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) def test_scaled_softmax_fwd_bwd(dtype): """test scaled softmax""" B, H, S = (16, 4, 32) @@ -1034,7 +1084,7 @@ class TestSoftmax: assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3) @staticmethod - @pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) + @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) def test_scaled_masked_softmax_fwd_bwd(dtype): """test scaled masked softmax""" B, H, S = (16, 4, 32) @@ -1058,7 +1108,7 @@ class TestSoftmax: assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3) @staticmethod - @pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) + @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) def test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype): """test scaled upper triang masked softmax""" B, S = (16, 32) @@ -1068,7 +1118,7 @@ class TestSoftmax: x.stop_gradient = False dy = paddle.uniform(shape=(B, S, S), dtype=dtype) - mask = paddle.ones((S, S), dtype='int32') + mask = paddle.ones((S, S), dtype="int32") col_beg, col_end = 1, S for row in range(0, S): mask[row, col_beg:col_end] = 0 @@ -1087,7 +1137,7 @@ class TestSoftmax: assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3) -@pytest.mark.parametrize('update_weight_scale_inv', [True, False]) +@pytest.mark.parametrize("update_weight_scale_inv", [True, False]) def test_amax_and_scale_update(update_weight_scale_inv): """Test update_scale""" num_gemm = 6 @@ -1097,11 +1147,11 @@ def test_amax_and_scale_update(update_weight_scale_inv): fp8_max = recipe.fp8_format.value.max_fwd non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2)) - amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32') + amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32") rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0) rolled_history_ref[0] = 0.0 amax_tensor = paddle.max(amax_history_tensor, axis=0) - scale_tensor = paddle.ones(shape=[num_gemm], dtype='float32') + scale_tensor = paddle.ones(shape=[num_gemm], dtype="float32") def calc_ref(amax, scale, fp8_max, margin=0): """Calculate reference scale""" @@ -1110,12 +1160,12 @@ def test_amax_and_scale_update(update_weight_scale_inv): sf = paddle.where(paddle.isfinite(amax), sf, scale) return sf - scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.) + scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.0) if update_weight_scale_inv: - scale_inv_ref = 1. / scale_ref + scale_inv_ref = 1.0 / scale_ref else: scale_inv_ref = paddle.zeros_like(scale_tensor) - scale_inv_ref = paddle.where(non_weight_mask, 1. 
/ scale_ref, scale_inv_ref) + scale_inv_ref = paddle.where(non_weight_mask, 1.0 / scale_ref, scale_inv_ref) # Placeholder scale_actual = paddle.zeros_like(scale_tensor) @@ -1123,13 +1173,15 @@ def test_amax_and_scale_update(update_weight_scale_inv): if update_weight_scale_inv: non_weight_mask = paddle.empty([0]) - tex.amax_and_scale_update_inplace(_amax_history=amax_history_tensor, - _scale=scale_actual, - _scale_inv=scale_inv_actual, - non_weight_mask=non_weight_mask, - fp8_dtype=int(fp8_dtype), - margin=0., - amax_compute="max") + tex.amax_and_scale_update_inplace( + _amax_history=amax_history_tensor, + _scale=scale_actual, + _scale_inv=scale_inv_actual, + non_weight_mask=non_weight_mask, + fp8_dtype=int(fp8_dtype), + margin=0.0, + amax_compute="max", + ) assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7) assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7) @@ -1141,8 +1193,8 @@ def test_update_latest_history(): num_gemm = 6 history_len = 1024 - amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32') - amax = paddle.rand(shape=[num_gemm], dtype='float32') + amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32") + amax = paddle.rand(shape=[num_gemm], dtype="float32") tex.update_latest_amax_history_inplace(_history=amax_history_tensor, amax=amax) diff --git a/tests/paddle/test_parallel.py b/tests/paddle/test_parallel.py index 9c8f399e3e01283b64b070a382b7cc3b739bbfaa..f07d56d44bd9539ac228fe2046483b5bc158b002 100644 --- a/tests/paddle/test_parallel.py +++ b/tests/paddle/test_parallel.py @@ -22,7 +22,7 @@ class TestParallelLinear(TestDistributed): @unittest.skipIf(not gpu_has_fp8, reason) def test_linear_tp(self): """Tests linear with tensor parallel in BF16""" - self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_tp.py')) + self.run_2gpu(str(test_root / "parallel_tests" / "linear_tp.py")) class TestParallelLayerNormLinear(TestDistributed): @@ -32,7 +32,7 @@ class TestParallelLayerNormLinear(TestDistributed): @unittest.skipIf(not gpu_has_fp8, reason) def test_layernorm_linear_tp(self): """Tests layernorm_linear with tensor parallel in BF16""" - self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_linear_tp.py')) + self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_linear_tp.py")) class TestParallelLayerNormMLP(TestDistributed): @@ -42,7 +42,7 @@ class TestParallelLayerNormMLP(TestDistributed): @unittest.skipIf(not gpu_has_fp8, reason) def test_layernorm_mlp_tp(self): """Tests layernorm_mlp with tensor parallel in BF16""" - self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_mlp_tp.py')) + self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_mlp_tp.py")) class TestAmaxReduction(TestDistributed): @@ -52,7 +52,7 @@ class TestAmaxReduction(TestDistributed): @unittest.skipIf(not gpu_has_fp8, reason) def test_amax_reduction(self): """Tests amax reduction""" - self.run_2gpu(str(test_root / 'parallel_tests' / 'amax_reduction.py')) + self.run_2gpu(str(test_root / "parallel_tests" / "amax_reduction.py")) class TestPipelineParallel(TestDistributed): @@ -62,7 +62,7 @@ class TestPipelineParallel(TestDistributed): @unittest.skipIf(not gpu_has_fp8, reason) def test_pipeline_parallel(self): """Tests pipeline parallel""" - self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_pp.py')) + self.run_2gpu(str(test_root / "parallel_tests" / "linear_pp.py")) class TestGroupSharding(TestDistributed): @@ -72,7 +72,7 @@ class TestGroupSharding(TestDistributed): @unittest.skipIf(not gpu_has_fp8, 
reason) def test_group_sharding(self): """Tests group sharding""" - self.run_2gpu(str(test_root / 'parallel_tests' / 'group_sharding.py')) + self.run_2gpu(str(test_root / "parallel_tests" / "group_sharding.py")) class TestParallelAttention(TestDistributed): @@ -82,7 +82,7 @@ class TestParallelAttention(TestDistributed): @unittest.skipIf(not gpu_has_fp8, reason) def test_attention_tp(self): """Tests TransMultiHeadAttentionformer Layer with tensor parallel in BF16""" - self.run_2gpu(str(test_root / 'parallel_tests' / 'attention_tp.py')) + self.run_2gpu(str(test_root / "parallel_tests" / "attention_tp.py")) class TestParallelTransformerLayer(TestDistributed): @@ -92,8 +92,8 @@ class TestParallelTransformerLayer(TestDistributed): @unittest.skipIf(not gpu_has_fp8, reason) def test_transformer_tp(self): """Tests Transformer Layer with tensor parallel in BF16""" - self.run_2gpu(str(test_root / 'parallel_tests' / 'transformer_tp.py')) + self.run_2gpu(str(test_root / "parallel_tests" / "transformer_tp.py")) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/paddle/test_recompute.py b/tests/paddle/test_recompute.py index 769921c6f57045d82c63e34e26c1c5b7f048f3fb..02dddad2107483c8528bca6327fc0dcde0de2691 100644 --- a/tests/paddle/test_recompute.py +++ b/tests/paddle/test_recompute.py @@ -17,7 +17,7 @@ is_fp8_supported, reason = is_fp8_available() @pytest.mark.skipif(not is_fp8_supported, reason=reason) -@pytest.mark.parametrize('use_reentrant', [False, True]) +@pytest.mark.parametrize("use_reentrant", [False, True]) def test_transformer_encoder_recompute(use_reentrant): """ Test TransformerLayer encoder recompute @@ -29,17 +29,17 @@ def test_transformer_encoder_recompute(use_reentrant): """Launch training in subprocess and check output""" try: cmd = [ - 'python', - str(test_root / 'recompute_tests' / 'recompute_transformer_encoder.py'), + "python", + str(test_root / "recompute_tests" / "recompute_transformer_encoder.py"), str(int(enable_recompute)), - str(int(use_reentrant)) + str(int(use_reentrant)), ] result = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True) print(result) - loss_match = re.search(r'Loss:\s+(-?\d+\.\d+)', result) - memory_match = re.search(r'Peak memory:\s+(\d+)', result) + loss_match = re.search(r"Loss:\s+(-?\d+\.\d+)", result) + memory_match = re.search(r"Peak memory:\s+(\d+)", result) loss_value = float(loss_match.group(1)) memory_value = int(memory_match.group(1)) diff --git a/tests/paddle/test_sanity_import.py b/tests/paddle/test_sanity_import.py index 8245de4bcf0c028aa233c0f10ff261146334984c..9b38d543da7a687b8c952b110953686025b39065 100644 --- a/tests/paddle/test_sanity_import.py +++ b/tests/paddle/test_sanity_import.py @@ -3,4 +3,5 @@ # See LICENSE for license information. 
import transformer_engine.paddle + print("OK") diff --git a/tests/paddle/utils.py b/tests/paddle/utils.py index 8d65711939ab532da5dd35b91027e4faaad32177..572af66ff9979b9dc5a0a66c111a8411532c32d0 100644 --- a/tests/paddle/utils.py +++ b/tests/paddle/utils.py @@ -11,7 +11,7 @@ import paddle from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker -import transformer_engine # pylint: disable=unused-import +import transformer_engine # pylint: disable=unused-import from transformer_engine.paddle.constants import ( TE_DType, AttnBiasType, @@ -19,7 +19,9 @@ from transformer_engine.paddle.constants import ( FusedAttnBackend, ) from transformer_engine.paddle.fp8 import FP8TensorMeta -from transformer_engine import transformer_engine_paddle as tex # pylint: disable=wrong-import-order +from transformer_engine import ( + transformer_engine_paddle as tex, +) # pylint: disable=wrong-import-order def create_fp8_meta(num_gemms=1, amax_history_len=10): @@ -31,18 +33,14 @@ def create_fp8_meta(num_gemms=1, amax_history_len=10): return fp8_meta -def assert_allclose(actual, - desired, - rtol=1e-05, - atol=1e-08, - equal_nan=True, - err_msg='', - verbose=True): +def assert_allclose( + actual, desired, rtol=1e-05, atol=1e-08, equal_nan=True, err_msg="", verbose=True +): """Compare two input paddle tensors""" if isinstance(actual, paddle.Tensor): - actual = paddle.cast(actual, 'float32') + actual = paddle.cast(actual, "float32") if isinstance(desired, paddle.Tensor): - desired = paddle.cast(desired, 'float32') + desired = paddle.cast(desired, "float32") if len(actual.shape) == 0: actual = actual.item() desired = desired.item() @@ -54,8 +52,9 @@ def assert_allclose(actual, def assert_shape(inp, expected_shape): """Assert the shape of input tensor equals to expected shape""" - assert inp.shape == expected_shape, f"Expected tensor shape: {expected_shape} != " \ - f"actual tensor shape: {inp.shape}" + assert ( + inp.shape == expected_shape + ), f"Expected tensor shape: {expected_shape} != actual tensor shape: {inp.shape}" def is_devices_enough(required): @@ -91,12 +90,21 @@ def set_random_seed(seed): np.random.seed(seed + 100 * pp_rank) seed_offset = seed + 1024 + paddle.distributed.get_world_size() - global_seed = (seed_offset + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) + - sharding_rank * (mp_size * pp_size * dp_size)) + global_seed = ( + seed_offset + + pp_rank * (mp_size) + + dp_rank * (mp_size * pp_size) + + sharding_rank * (mp_size * pp_size * dp_size) + ) seed_offset += paddle.distributed.get_world_size() - local_seed = (seed_offset + mp_rank + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) + - sharding_rank * (mp_size * pp_size * dp_size)) + local_seed = ( + seed_offset + + mp_rank + + pp_rank * (mp_size) + + dp_rank * (mp_size * pp_size) + + sharding_rank * (mp_size * pp_size * dp_size) + ) tracker = get_rng_state_tracker() # tracker.reset() diff --git a/tests/pytorch/distributed/print_logs.py b/tests/pytorch/distributed/print_logs.py index 20062e064fbad142d239cd2453dd31f190686a53..6c25db49452ba5c74668a613784119f401499a8b 100644 --- a/tests/pytorch/distributed/print_logs.py +++ b/tests/pytorch/distributed/print_logs.py @@ -112,13 +112,17 @@ def perf_and_loss_plots(): lm_loss_data.append(lm_data["loss"]) lm_perf_data.append(lm_data["perf"]) save_plot( - model_config + " loss", legend, - lm_loss_data, model_config + "_loss.png", + model_config + " loss", + legend, + lm_loss_data, + model_config + "_loss.png", "LM-Loss", ) save_plot( 
model_config + " perf", - legend, lm_perf_data, model_config + "_perf.png", + legend, + lm_perf_data, + model_config + "_perf.png", "Time per step (ms)", ) diff --git a/tests/pytorch/distributed/test_convergence.py b/tests/pytorch/distributed/test_convergence.py index 2a4c6a7282480be15ba0fdbb324cb424d259305b..fa3ba1e3f38830ec96b56392170d897f5c56da32 100644 --- a/tests/pytorch/distributed/test_convergence.py +++ b/tests/pytorch/distributed/test_convergence.py @@ -68,7 +68,9 @@ def get_filename( config = f"gpt3_{model}_dp{dp}_tp{tp}_pp{pp}_sp{sp}" config_dir = os.path.join(mlm_log_dir, config) os.makedirs(config_dir, exist_ok=True) - fname = f"{'te' if use_te else 'megatron'}" + (f"_fp8_{fp8_recipe}" if fp8_recipe else "") + ".txt" + fname = ( + f"{'te' if use_te else 'megatron'}" + (f"_fp8_{fp8_recipe}" if fp8_recipe else "") + ".txt" + ) return os.path.join(config_dir, fname) @@ -106,4 +108,5 @@ def test_distributed(dtype, fp8_recipe, dp, tp, pp, sp, use_te, model): TRANSFORMER_IMPL="transformer_engine" if use_te else "local", **asdict(model_configs[model]), ), - check=True) + check=True, + ) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index c35fc339318c8b1a8ec205170e8d9cf54522e86d..9b9b7686c250019a4b4ac74e1c0c9dd40bbc3a42 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -9,9 +9,10 @@ from transformer_engine.pytorch.attention import DotProductAttention import transformer_engine_torch as tex from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn -dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16} +dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} -def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend='FlashAttention'): + +def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention"): """Test DotProductAttention module with context parallelism""" os.environ["NVTE_FLASH_ATTN"] = "0" @@ -22,11 +23,13 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= if kernel_backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" config = model_configs_fused_attn[model] - if qkv_format == 'thd' and (config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"): + if qkv_format == "thd" and ( + config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias" + ): return - rank = int(os.getenv('RANK', '0')) - world_size = int(os.getenv('WORLD_SIZE', '1')) + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) if dist.is_initialized(): world_size = dist.get_world_size() @@ -38,51 +41,76 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= print(f"[INFO] world_size:{world_size}, rank:{rank}") - dist.init_process_group(backend='nccl', world_size=world_size, rank=rank) + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) # create flash attn comm group for CP cp_comm_ranks = range(world_size) - assert(rank in cp_comm_ranks) - cp_comm_group = dist.new_group(cp_comm_ranks, backend='nccl') + assert rank in cp_comm_ranks + cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") - assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!" 
+ assert config.attn_mask_type in [ + "causal", + "no_mask", + ], f"{config.attn_mask_type} is an unsupported attention mask type!" - if kernel_backend == 'FusedAttention' and qkv_format == 'thd': - if 'causal' in config.attn_mask_type: - config.attn_mask_type = 'padding_causal' + if kernel_backend == "FusedAttention" and qkv_format == "thd": + if "causal" in config.attn_mask_type: + config.attn_mask_type = "padding_causal" else: - config.attn_mask_type = 'padding' + config.attn_mask_type = "padding" # instantiate core attn module - core_attn = DotProductAttention(config.num_heads, - config.head_dim, - num_gqa_groups=config.num_gqa_groups, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=config.attn_mask_type) + core_attn = DotProductAttention( + config.num_heads, + config.head_dim, + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=config.attn_mask_type, + ) core_attn = core_attn.cuda() # create flash attn inputs if qkv_format == "bshd": q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim) - kv_input_shape = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim) - attn_output_shape = (config.batch_size, config.max_seqlen_q, config.num_heads*config.head_dim) + kv_input_shape = ( + config.batch_size, + config.max_seqlen_kv, + config.num_gqa_groups, + config.head_dim, + ) + attn_output_shape = ( + config.batch_size, + config.max_seqlen_q, + config.num_heads * config.head_dim, + ) cu_seqlens_q = None cu_seqlens_kv = None elif qkv_format == "sbhd": q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim) - kv_input_shape = (config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim) - attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim) + kv_input_shape = ( + config.max_seqlen_kv, + config.batch_size, + config.num_gqa_groups, + config.head_dim, + ) + attn_output_shape = ( + config.max_seqlen_q, + config.batch_size, + config.num_heads * config.head_dim, + ) cu_seqlens_q = None cu_seqlens_kv = None elif qkv_format == "thd": - seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32) + seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to( + torch.int32 + ) seqlens_q = seqlens_q - seqlens_q % (world_size * 2) cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)]) cu_seqlens_kv = cu_seqlens_q q_input_shape = (cu_seqlens_q[-1], config.num_heads, config.head_dim) kv_input_shape = (cu_seqlens_kv[-1], config.num_gqa_groups, config.head_dim) - attn_output_shape = (cu_seqlens_q[-1], config.num_heads*config.head_dim) + attn_output_shape = (cu_seqlens_q[-1], config.num_heads * config.head_dim) cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda() cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda() else: @@ -111,7 +139,9 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= for x in [q, k, v]: x.requires_grad = True out = core_attn( - q, k, v, + q, + k, + v, core_attention_bias_type=config.attn_bias_type, core_attention_bias=bias, cu_seqlens_q=cu_seqlens_q, @@ -120,17 +150,28 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= out.backward(dout) # run core_attn wit CP - q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])] + 
q_, k_, v_, dout_, *rest = [ + x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias]) + ] bias_ = rest[0] if len(rest) else None if qkv_format == "bshd" or qkv_format == "sbhd": - seq_dim = qkv_format.index('s') - q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \ - for x in [q_, k_, v_, dout_]] - seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device) + seq_dim = qkv_format.index("s") + q_, k_, v_, dout_ = [ + x.view( + *x.shape[:seq_dim], + 2 * world_size, + x.shape[seq_dim] // (2 * world_size), + *x.shape[(seq_dim + 1) :], + ) + for x in [q_, k_, v_, dout_] + ] + seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device) q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] - q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]] + q_, k_, v_, dout_ = [ + x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_] + ] elif qkv_format == "thd": - seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank) + seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank) seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank) q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] @@ -140,14 +181,18 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= assert False, f"{qkv_format} is an unsupported qkv_format!" q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: - bias_ = bias_.view(*bias_.shape[:-2], 2*world_size, bias_.shape[-2]//(2*world_size), bias_.shape[-1]) + bias_ = bias_.view( + *bias_.shape[:-2], 2 * world_size, bias_.shape[-2] // (2 * world_size), bias_.shape[-1] + ) bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream()) - max_seqlen_q = config.max_seqlen_q + max_seqlen_q = config.max_seqlen_q max_seqlen_kv = config.max_seqlen_kv out_ = core_attn( - q_, k_, v_, + q_, + k_, + v_, core_attention_bias_type=config.attn_bias_type, core_attention_bias=bias_, cu_seqlens_q=cu_seqlens_q, @@ -158,23 +203,32 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= out_.backward(dout_) for x in [out_, q_.grad, k_.grad, v_.grad]: - assert(torch.all(~torch.isnan(x))) - assert(torch.all(~torch.isinf(x))) + assert torch.all(~torch.isnan(x)) + assert torch.all(~torch.isinf(x)) # compare results with and without CP tols = dict(atol=5e-3, rtol=5e-3) - if dtype == 'bf16': + if dtype == "bf16": if config.num_heads == config.num_gqa_groups: tols = dict(atol=2.5e-2, rtol=2.5e-2) else: tols = dict(atol=3.5e-2, rtol=3.5e-2) if qkv_format == "bshd" or qkv_format == "sbhd": - dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \ - for x in [q.grad, k.grad, v.grad, out]] + dq, dk, dv, out = [ + x.view( + *x.shape[:seq_dim], + 2 * world_size, + x.shape[seq_dim] // (2 * world_size), + *x.shape[(seq_dim + 1) :], + ) + for x in [q.grad, k.grad, v.grad, out] + ] dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] - dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \ - for x in 
[q_.grad, k_.grad, v_.grad, out_]] + dq_, dk_, dv_, out_ = [ + x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) + for x in [q_.grad, k_.grad, v_.grad, out_] + ] elif qkv_format == "thd": dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]] @@ -208,9 +262,11 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= else: assert False, f"{qkv_format} is an unsupported qkv_format!" + def main(**kwargs): run_dpa_with_cp(**kwargs) + if __name__ == "__main__": - kwargs = dict(arg.split('=') for arg in sys.argv[2:]) + kwargs = dict(arg.split("=") for arg in sys.argv[2:]) main(**kwargs) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index ccd82c1d067e404eee79b83b4654275597857284..cca515b63d0b7ba70f737a320b8b65488d9169dc 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -91,10 +91,10 @@ class ModelConfig: self.max_seqlen_q = max_seqlen_q self.max_seqlen_kv = max_seqlen_kv self.dropout_p = dropout_p - self.attn_mask_type = attn_mask_type - self.attn_bias_type = attn_bias_type - self.alibi_type = alibi_type - self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross" + self.attn_mask_type = attn_mask_type + self.attn_bias_type = attn_bias_type + self.alibi_type = alibi_type + self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross" self.num_layers = num_layers self.bias_shape = bias_shape @@ -184,28 +184,29 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool: return False return True + def _is_unfused_attention_supported( config: ModelConfig, qkv_format: str, - ) -> bool: +) -> bool: """Check if UnfusedDotProductAttention supports a model configuration""" - if ("padding" in config.attn_mask_type): + if "padding" in config.attn_mask_type: return False - if ("causal" in config.attn_mask_type and config.attn_type == 'cross'): + if "causal" in config.attn_mask_type and config.attn_type == "cross": return False - if qkv_format == 'thd': + if qkv_format == "thd": return False return True model_configs_base = { # test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend - "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0 - "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0 - "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1 - "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1 - "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference - "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference + "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0 + "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0 + "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1 + "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1 + "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference + "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference } @@ -221,13 +222,13 @@ def get_swa(seq_q, seq_kv, w=None): if w is None: w = torch.randint(0, seq_kv, [2], dtype=torch.int32, 
device="cuda") m = torch.ones(seq_q, seq_kv, dtype=torch.bool, device="cuda") - mu = torch.triu(m, diagonal=seq_kv-seq_q-w[0]) - ml = torch.tril(mu, diagonal=seq_kv-seq_q+w[1]) - ml = ~ ml + mu = torch.triu(m, diagonal=seq_kv - seq_q - w[0]) + ml = torch.tril(mu, diagonal=seq_kv - seq_q + w[1]) + ml = ~ml return w, ml -@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_base]) @pytest.mark.parametrize("model", model_configs_base.keys()) @@ -236,8 +237,9 @@ def get_swa(seq_q, seq_kv, w=None): @pytest.mark.parametrize("qkv_layout", [None]) @pytest.mark.parametrize("swa", [False]) @pytest.mark.parametrize("pad_between_seqs", [False]) -def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, - workspace_opt, qkv_layout, swa, pad_between_seqs): +def test_dot_product_attention( + dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs +): """Test DotProductAttention module""" # Get configs @@ -251,36 +253,43 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, else: qkv_layout = "sbhd_sb2hd" if "3" in qkv_layout and config.attn_type == "cross": - pytest.skip( - "No need to test this layout for cross attention" - ) + pytest.skip("No need to test this layout for cross attention") # Skip if only unfused backend is supported - qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format) if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512: os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( - config, dtype, qkv_layout=qkv_layout, + config, + dtype, + qkv_layout=qkv_layout, ) if swa: fused_attn_supported = False flash_attn_supported = _is_flash_attention_supported(config) if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2: pytest.skip("Less than two backends to compare.") - if (qkv_format == 'thd' and 'padding' not in config.attn_mask_type): + if qkv_format == "thd" and "padding" not in config.attn_mask_type: pytest.skip("THD layout requires padding/padding_causal mask type.") # d=256 is supported by cuDNN 9.0+ for inference but not training - is_training = (config.head_dim <= 128) + is_training = config.head_dim <= 128 # UnfusedDotProductAttention backend if unfused_attn_supported: if swa: attn_mask_type = config.attn_mask_type config.attn_mask_type = "arbitrary" unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( - dtype, config, "UnfusedDotProductAttention", - ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training, + dtype, + config, + "UnfusedDotProductAttention", + ckpt_attn, + qkv_layout, + workspace_opt, + swa, + pad_between_seqs, + is_training, ) if swa: config.attn_mask_type = attn_mask_type @@ -289,51 +298,79 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, if fused_attn_supported: if len(fused_attn_backend) == 1: fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( - dtype, config, "FusedAttention", - ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training, + dtype, + config, + "FusedAttention", + ckpt_attn, + qkv_layout, + workspace_opt, + swa, + 
pad_between_seqs,
+                is_training,
             )
         if len(fused_attn_backend) == 2:
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FusedAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
             fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
-                dtype, config, "FusedAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
     # FlashAttention backend
     if flash_attn_supported:
         flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-            dtype, config, "FlashAttention",
-            ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+            dtype,
+            config,
+            "FlashAttention",
+            ckpt_attn,
+            qkv_layout,
+            workspace_opt,
+            swa,
+            pad_between_seqs,
+            is_training,
         )
     if unfused_attn_supported and fused_attn_supported:
         logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
         torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
-        for i,_ in enumerate(unfused_attn_bwd):
+        for i, _ in enumerate(unfused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
     if unfused_attn_supported and flash_attn_supported:
         logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
         torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
-        for i,_ in enumerate(flash_attn_bwd):
+        for i, _ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
     if fused_attn_supported and flash_attn_supported:
         logging.info("[test_dot_product_attention]: fused attn vs flash attn")
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
-        for i,_ in enumerate(flash_attn_bwd):
+        for i, _ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
     if fused_attn_supported and len(fused_attn_backend) == 2:
         logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
         torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
-        for i,_ in enumerate(fused_attn_bwd):
+        for i, _ in enumerate(fused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("model_configs", [model_configs_base])
 @pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
@@ -344,22 +381,22 @@ def test_dpa_checkpoint(dtype, model_configs, model):
 model_configs_mask = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
-    "mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
-    "mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
-    "mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
-    "mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
-    "mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
-    "mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"),
-    "mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
-    "mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
-    "mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
+    "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
+    "mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
+    "mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+    "mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
+    "mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
+    "mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
+    "mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"),
+    "mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
+    "mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
+    "mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
     "mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
     "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
 }
-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_mask])
 @pytest.mark.parametrize("model", model_configs_mask.keys())
@@ -370,34 +407,48 @@ def test_dpa_mask(dtype, model_configs, model):
 model_configs_bias = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
-    "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
-    "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
-    "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
-    "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped
-    "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped
-    "bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped
-    "bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped
-    "bias_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"), # skipped
-    "bias_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"), # skipped
-    "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
-    "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
-    "bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
-    "bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
-    "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
-    "bias_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"), # skipped
-    "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
-    "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
-    "bias_4_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"), # skipped
-    "bias_4_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"), # skipped
-    "bias_4_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"), # skipped
-    "bias_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), # skipped
-    "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
-    "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
+    "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
+    "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
+    "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
+    "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
+    "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped
+    "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped
+    "bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped
+    "bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped
+    "bias_2_2": ModelConfig(
+        4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"
+    ), # skipped
+    "bias_2_3": ModelConfig(
+        2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"
+    ), # skipped
+    "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
+    "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
+    "bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
+    "bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
+    "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
+    "bias_3_3": ModelConfig(
+        2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"
+    ), # skipped
+    "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
+    "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
+    "bias_4_0": ModelConfig(
+        4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"
+    ), # skipped
+    "bias_4_1": ModelConfig(
+        2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"
+    ), # skipped
+    "bias_4_2": ModelConfig(
+        4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"
+    ), # skipped
+    "bias_4_3": ModelConfig(
+        2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"
+    ), # skipped
+    "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
+    "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
 }
-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_bias])
 @pytest.mark.parametrize("model", model_configs_bias.keys())
@@ -408,23 +459,38 @@ def test_dpa_bias(dtype, model_configs, model):
 model_configs_bias_shapes = {
     # test: b, h, hg, d, sq, skv, p,
-    "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0,
-        # mask, bias, bias_shape,
-        "no_mask", "post_scale_bias", bias_shape='11ss'),
-    "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0,
-        "no_mask", "post_scale_bias", bias_shape='1hss'),
-    "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
-
"no_mask", "post_scale_bias", bias_shape='b1ss'), - "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, - "no_mask", "post_scale_bias", bias_shape='bhss'), - "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, - "causal", "alibi", bias_shape='1hss', alibi_type='custom'), - "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, - "causal", "alibi", bias_shape='bhss', alibi_type='custom'), + "bias_1_0": ModelConfig( + 4, + 16, + 16, + 64, + 128, + 128, + 0.0, + # mask, bias, bias_shape, + "no_mask", + "post_scale_bias", + bias_shape="11ss", + ), + "bias_1_1": ModelConfig( + 2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss" + ), + "bias_1_2": ModelConfig( + 4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss" + ), + "bias_1_3": ModelConfig( + 2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss" + ), + "bias_1_4": ModelConfig( + 4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom" + ), + "bias_1_5": ModelConfig( + 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom" + ), } -@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_bias_shapes]) @pytest.mark.parametrize("model", model_configs_bias_shapes.keys()) @@ -435,10 +501,10 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { # test: b, h, hg, d, sq, skv, p, mask, bias - "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), - "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), + "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), + "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), + "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), } @@ -453,10 +519,14 @@ def test_dpa_sliding_window(dtype, model_configs, model): model_configs_alibi_slopes = { # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type - "alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"), - "alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"), - "alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type= "custom"), - "alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type= "custom"), + "alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"), + "alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"), + "alibi_2_0": ModelConfig( + 2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom" + ), + "alibi_2_1": ModelConfig( + 1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom" + ), } @@ -470,27 +540,35 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): qkv_layouts = [ - 'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd', - 'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd', - ] + "sb3hd", + "sbh3d", + 
"sbhd_sb2hd", + "sbhd_sbh2d", + "sbhd_sbhd_sbhd", + "bs3hd", + "bsh3d", + "bshd_bs2hd", + "bshd_bsh2d", + "bshd_bshd_bshd", +] model_configs_layout = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), - "layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"), - "layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), - "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), + "layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), + "layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), + "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"), + "layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), + "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), "layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), - "layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), - "layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"), + "layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), + "layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"), } -@pytest.mark.skipif(get_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.") +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_layout]) @pytest.mark.parametrize("model", model_configs_layout.keys()) @@ -500,26 +578,28 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False) -qkv_layouts_thd = ['t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd'] +qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), - "layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), - "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"), - "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - 
"layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), + "layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), + "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), + "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), + "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"), + "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), + "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"), } -@pytest.mark.skipif(get_cudnn_version() < (9,0,0), reason="cuDNN 9.0.0+ is required.") -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+.") +@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.") +@pytest.mark.skipif( + get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+." +) @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_layout_thd]) @pytest.mark.parametrize("model", model_configs_layout_thd.keys()) @@ -527,24 +607,26 @@ model_configs_layout_thd = { def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with different QKV layouts""" pad_between_seqs = False - test_dot_product_attention(dtype, model_configs, model, False, True, - qkv_layout, False, pad_between_seqs) + test_dot_product_attention( + dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs + ) pad_between_seqs = True - test_dot_product_attention(dtype, model_configs, model, False, True, - qkv_layout, False, pad_between_seqs) + test_dot_product_attention( + dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs + ) def _run_dot_product_attention( - dtype: torch.dtype, - config: ModelConfig, - backend: str, - ckpt_attn: bool, - qkv_layout: str, - workspace_opt: bool, - swa: bool, - pad_between_seqs: bool, - is_training: bool, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + dtype: torch.dtype, + config: ModelConfig, + backend: str, + ckpt_attn: bool, + qkv_layout: str, + workspace_opt: bool, + swa: bool, + pad_between_seqs: bool, + is_training: bool, +) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Run DotProductAttention module with one forward pass and one backward pass""" # Set RNG and environment varables @@ -558,22 +640,27 @@ def _run_dot_product_attention( os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0" # Create seqlens - qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) - if "padding" in config.attn_mask_type or qkv_format == 'thd': - if config.attn_type == 'self': - seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size], - dtype=torch.int32, device="cuda") + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if 
i.isalpha()]) + if "padding" in config.attn_mask_type or qkv_format == "thd": + if config.attn_type == "self": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) seqlens_kv = seqlens_q - if config.attn_type == 'cross': - seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size], - dtype=torch.int32, device="cuda") - seqlens_kv = torch.randint(1, config.max_seqlen_kv, [config.batch_size], - dtype=torch.int32, device="cuda") + if config.attn_type == "cross": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.randint( + 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda" + ) else: - seqlens_q = torch.full([config.batch_size], config.max_seqlen_q, - dtype=torch.int32, device="cuda") - seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv, - dtype=torch.int32, device="cuda") + seqlens_q = torch.full( + [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.full( + [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" + ) cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) @@ -586,7 +673,7 @@ def _run_dot_product_attention( pad_len = [0] * config.batch_size if pad_between_seqs: max_pad_len = 3 - pad_len = torch.randint(0, max_pad_len+1, [config.batch_size], device="cuda") #3 + pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda") # 3 seqlens_q_after_pad = seqlens_q + pad_len seqlens_kv_after_pad = seqlens_kv + pad_len cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0) @@ -595,25 +682,58 @@ def _run_dot_product_attention( # Create attention mask if padding attention_mask = None if "padding" in config.attn_mask_type: - if config.attn_type == 'self': + if config.attn_type == "self": attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) for i in range(config.batch_size): - attention_mask_q = torch.cat([attention_mask_q, - torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i])) - .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0) + attention_mask_q = torch.cat( + [ + attention_mask_q, + torch.Tensor( + [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i]) + ) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) attention_mask = attention_mask_q.to(device="cuda") - if config.attn_type == 'cross': + if config.attn_type == "cross": attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool) for i in range(config.batch_size): - attention_mask_q = torch.cat([attention_mask_q, - torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i])) - .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0) - attention_mask_kv = torch.cat([attention_mask_kv, torch.Tensor( - [False]*seqlens_kv[i] + [True]*(config.max_seqlen_kv-seqlens_kv[i])) - .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0) + attention_mask_q = torch.cat( + [ + attention_mask_q, + torch.Tensor( + [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i]) + ) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) + attention_mask_kv = 
torch.cat( + [ + attention_mask_kv, + torch.Tensor( + [False] * seqlens_kv[i] + + [True] * (config.max_seqlen_kv - seqlens_kv[i]) + ) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) attention_mask = ( - attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda")) + attention_mask_q.to(device="cuda"), + attention_mask_kv.to(device="cuda"), + ) window_size = None if swa: window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv) @@ -623,62 +743,84 @@ def _run_dot_product_attention( alibi_slopes = None if config.attn_bias_type == "alibi" and config.alibi_type == "custom": if config.bias_shape == "1hss": - alibi_slopes = torch.randn( - config.num_heads).abs().to(dtype=torch.float32, device="cuda") + alibi_slopes = ( + torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda") + ) if config.bias_shape == "bhss": - alibi_slopes = torch.randn( - config.batch_size, config.num_heads).abs().to(dtype=torch.float32, device="cuda") + alibi_slopes = ( + torch.randn(config.batch_size, config.num_heads) + .abs() + .to(dtype=torch.float32, device="cuda") + ) # Create input tensors dim_to_num = { - 'b' : config.batch_size, - 'sq' : config.max_seqlen_q, - 'skv': config.max_seqlen_kv, - 'h' : config.num_heads, - 'hg' : config.num_gqa_groups, - 'd' : config.head_dim, - 't' : cu_seqlens_q_after_pad[-1], - 'tg' : cu_seqlens_kv_after_pad[-1], - '3' : 3, - '2' : 2, - '1' : 1, - } + "b": config.batch_size, + "sq": config.max_seqlen_q, + "skv": config.max_seqlen_kv, + "h": config.num_heads, + "hg": config.num_gqa_groups, + "d": config.head_dim, + "t": cu_seqlens_q_after_pad[-1], + "tg": cu_seqlens_kv_after_pad[-1], + "3": 3, + "2": 2, + "1": 1, + } inp = [] inp_orig = [] - for i,layout in enumerate(qkv_layout.split('_')): - layout = '_'.join(layout) + for i, layout in enumerate(qkv_layout.split("_")): + layout = "_".join(layout) if i == 0: - layout = layout.replace('s', 'sq') + layout = layout.replace("s", "sq") else: - layout = layout.replace('s', 'skv') - layout = layout.replace('h', 'hg') - layout = layout.replace('t', 'tg') - tensor_shape = [dim_to_num[j] for j in layout.split('_')] + layout = layout.replace("s", "skv") + layout = layout.replace("h", "hg") + layout = layout.replace("t", "tg") + tensor_shape = [dim_to_num[j] for j in layout.split("_")] tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda") tensor_orig = tensor - if qkv_format == 'thd' and pad_between_seqs: - tensor_orig = torch.Tensor([]).to(device="cuda",dtype=dtype) - if layout in ['t_h_d', 't_3_h_d', 't_h_3_d']: - for i in range(1, config.batch_size+1): - valid_range = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1]) - pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i-1], cu_seqlens_q_after_pad[i]) - tensor[pad_range[0]:pad_range[1]] = 0.0 - tensor_orig = torch.cat([tensor_orig, tensor[valid_range[0]:valid_range[1]]], dim=0) - if layout in ['tg_hg_d', 'tg_2_hg_d', 'tg_hg_2_d']: - for i in range(1, config.batch_size+1): - valid_range = (cu_seqlens_kv_after_pad[i-1], cu_seqlens_kv_after_pad[i] - pad_len[i-1]) - pad_range = (cu_seqlens_kv_after_pad[i] - pad_len[i-1], cu_seqlens_kv_after_pad[i]) - tensor[pad_range[0]:pad_range[1]] = 0.0 - tensor_orig = torch.cat([tensor_orig, tensor[valid_range[0]:valid_range[1]]], dim=0) + if qkv_format == "thd" and pad_between_seqs: + tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + if layout in ["t_h_d", "t_3_h_d", "t_h_3_d"]: + for i in range(1, config.batch_size + 
1): + valid_range = ( + cu_seqlens_q_after_pad[i - 1], + cu_seqlens_q_after_pad[i] - pad_len[i - 1], + ) + pad_range = ( + cu_seqlens_q_after_pad[i] - pad_len[i - 1], + cu_seqlens_q_after_pad[i], + ) + tensor[pad_range[0] : pad_range[1]] = 0.0 + tensor_orig = torch.cat( + [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0 + ) + if layout in ["tg_hg_d", "tg_2_hg_d", "tg_hg_2_d"]: + for i in range(1, config.batch_size + 1): + valid_range = ( + cu_seqlens_kv_after_pad[i - 1], + cu_seqlens_kv_after_pad[i] - pad_len[i - 1], + ) + pad_range = ( + cu_seqlens_kv_after_pad[i] - pad_len[i - 1], + cu_seqlens_kv_after_pad[i], + ) + tensor[pad_range[0] : pad_range[1]] = 0.0 + tensor_orig = torch.cat( + [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0 + ) tensor_count = 1 split_dim = 0 - for dim, l in enumerate(layout.split('_')): + for dim, l in enumerate(layout.split("_")): if l.isdigit(): tensor_count = int(l) split_dim = dim break tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor] - tensors_orig = torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig] + tensors_orig = ( + torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig] + ) for j in range(tensor_count): if split_dim != 0: inp.append(tensors[j].squeeze(split_dim)) @@ -692,73 +834,77 @@ def _run_dot_product_attention( # Create ragged offsets for q/k/v seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = None, None, None, None - qkv_group = ''.join([x for x in qkv_layout if x not in 'bst']) - if qkv_format == 'thd': + qkv_group = "".join([x for x in qkv_layout if x not in "bst"]) + if qkv_format == "thd": seq_offsets_o = config.num_heads * config.head_dim * cu_seqlens_q_after_pad - if qkv_group == 'hd_hd_hd': + if qkv_group == "hd_hd_hd": seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad - if qkv_group in ['3hd', 'h3d']: + if qkv_group in ["3hd", "h3d"]: seq_offsets_q = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad seq_offsets_k = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad - if qkv_group in ['hd_2hd', 'hd_h2d']: + if qkv_group in ["hd_2hd", "hd_h2d"]: seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad seq_offsets_k = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad # Create output gradient - qkv_format_kv = '_'.join(qkv_format) - qkv_format_kv = qkv_format_kv.replace('s', 'sq') - out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')] + qkv_format_kv = "_".join(qkv_format) + qkv_format_kv = qkv_format_kv.replace("s", "sq") + out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda") out_grad_orig = out_grad - if qkv_format == 'thd' and pad_between_seqs: - out_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype) - if qkv_format_kv == 't_h_d': - for i in range(1, config.batch_size+1): - valid_range = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1]) - pad_range = (cu_seqlens_q_after_pad[i] 
- pad_len[i-1], cu_seqlens_q_after_pad[i]) - out_grad[pad_range[0]:pad_range[1]] = 0.0 - out_grad_orig = torch.cat([out_grad_orig, out_grad[valid_range[0]:valid_range[1]]], dim=0) + if qkv_format == "thd" and pad_between_seqs: + out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + if qkv_format_kv == "t_h_d": + for i in range(1, config.batch_size + 1): + valid_range = ( + cu_seqlens_q_after_pad[i - 1], + cu_seqlens_q_after_pad[i] - pad_len[i - 1], + ) + pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i - 1], cu_seqlens_q_after_pad[i]) + out_grad[pad_range[0] : pad_range[1]] = 0.0 + out_grad_orig = torch.cat( + [out_grad_orig, out_grad[valid_range[0] : valid_range[1]]], dim=0 + ) # Create bias - if config.attn_bias_type in ['no_bias', 'alibi']: + if config.attn_bias_type in ["no_bias", "alibi"]: bias = None - if config.attn_bias_type == 'post_scale_bias': - shape = '_'.join(config.bias_shape) - shape = shape.replace('_s_s', '_sq_skv') - tensor_shape = [dim_to_num[j] for j in shape.split('_')] + if config.attn_bias_type == "post_scale_bias": + shape = "_".join(config.bias_shape) + shape = shape.replace("_s_s", "_sq_skv") + tensor_shape = [dim_to_num[j] for j in shape.split("_")] bias = torch.randn(tensor_shape, dtype=dtype, device="cuda") - if config.bias_shape != '1hss': + if config.bias_shape != "1hss": bias.requires_grad = False # Create RNG _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER # Set up model - block = ( - DotProductAttention( - config.num_heads, - config.head_dim, - num_gqa_groups=config.num_gqa_groups, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=config.attn_mask_type, - sequence_parallel=False, - tp_size=1, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, - tp_group=None, - layer_number=1, - attention_type=config.attn_type, - ).to(dtype=dtype, device="cuda") - ) + block = DotProductAttention( + config.num_heads, + config.head_dim, + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=config.attn_mask_type, + sequence_parallel=False, + tp_size=1, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + tp_group=None, + layer_number=1, + attention_type=config.attn_type, + ).to(dtype=dtype, device="cuda") # Run a forward and backward pass if backend in ["FlashAttention", "UnfusedDotProductAttention"]: @@ -771,24 +917,28 @@ def _run_dot_product_attention( k = inp[1] v = inp[2] d_out = out_grad - out = block(q, k, v, - window_size=window_size, - attention_mask=attention_mask, - qkv_format=qkv_format, - max_seqlen_q=config.max_seqlen_q, - max_seqlen_kv=config.max_seqlen_kv, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - seq_offsets_q=seq_offsets_q, - seq_offsets_k=seq_offsets_k, - seq_offsets_v=seq_offsets_v, - seq_offsets_o=seq_offsets_o, - attn_mask_type=config.attn_mask_type, - checkpoint_core_attention=ckpt_attn, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias, - alibi_slopes=alibi_slopes, - fast_zero_fill=True) + out = block( + q, + k, + v, + window_size=window_size, + attention_mask=attention_mask, + qkv_format=qkv_format, + max_seqlen_q=config.max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + seq_offsets_q=seq_offsets_q, + seq_offsets_k=seq_offsets_k, + 
seq_offsets_v=seq_offsets_v, + seq_offsets_o=seq_offsets_o, + attn_mask_type=config.attn_mask_type, + checkpoint_core_attention=ckpt_attn, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias, + alibi_slopes=alibi_slopes, + fast_zero_fill=True, + ) if is_training: out.backward(d_out) @@ -798,18 +948,30 @@ def _run_dot_product_attention( else: return out, (None, None, None) if backend == "FusedAttention": - if qkv_format == 'thd' and pad_between_seqs: - out_orig = torch.Tensor([]).to(device="cuda",dtype=dtype) - q_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype) - k_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype) - v_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype) - for i in range(1, config.batch_size+1): - valid_range_q = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1]) - valid_range_kv = (cu_seqlens_kv_after_pad[i-1], cu_seqlens_kv_after_pad[i] - pad_len[i-1]) - out_orig = torch.cat([out_orig, out[valid_range_q[0]:valid_range_q[1]]], dim=0) - q_grad_orig = torch.cat([q_grad_orig, q.grad[valid_range_q[0]:valid_range_q[1]]], dim=0) - k_grad_orig = torch.cat([k_grad_orig, k.grad[valid_range_kv[0]:valid_range_kv[1]]], dim=0) - v_grad_orig = torch.cat([v_grad_orig, v.grad[valid_range_kv[0]:valid_range_kv[1]]], dim=0) + if qkv_format == "thd" and pad_between_seqs: + out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + for i in range(1, config.batch_size + 1): + valid_range_q = ( + cu_seqlens_q_after_pad[i - 1], + cu_seqlens_q_after_pad[i] - pad_len[i - 1], + ) + valid_range_kv = ( + cu_seqlens_kv_after_pad[i - 1], + cu_seqlens_kv_after_pad[i] - pad_len[i - 1], + ) + out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0) + q_grad_orig = torch.cat( + [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0 + ) + k_grad_orig = torch.cat( + [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 + ) + v_grad_orig = torch.cat( + [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 + ) if is_training: return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig) else: @@ -823,18 +985,18 @@ def _run_dot_product_attention( model_configs_te_layer = { # test: b, h, hg, d, sq, skv, p, mask, bias - "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), - "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), - "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), - "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), - "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), - "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"), - "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"), + "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), + "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), + "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), + "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), + "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 
2048, 0.0, "padding", "no_bias"), + "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"), + "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"), } -@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @pytest.mark.parametrize("model", model_configs_te_layer.keys()) @@ -842,7 +1004,9 @@ model_configs_te_layer = { @pytest.mark.parametrize("qkv_format", ["sbhd"]) @pytest.mark.parametrize("fused_qkv_params", [False]) @pytest.mark.parametrize("RoPE", [False]) -def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE): +def test_transformer_layer( + dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE +): """Test TransformerLayer module""" # Get configs @@ -916,7 +1080,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols) -@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @pytest.mark.parametrize("model", ["te_1_2", "te_2_0"]) @@ -926,22 +1090,24 @@ def test_te_layer_misc(dtype, model_configs, model, qkv_format): ckpt_attn = True fused_qkv_params = True RoPE = True - test_transformer_layer(dtype, model_configs, model, - ckpt_attn, qkv_format, fused_qkv_params, RoPE) + test_transformer_layer( + dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE + ) -@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"]) def test_te_layer_mqa_gqa(dtype, model_configs, model): """Test TransformerLayer module with MQA/GQA""" + def find_factors(x): - f = [] - for i in range(2, x + 1): - if x % i == 0: - f.append(i) - return f + f = [] + for i in range(2, x + 1): + if x % i == 0: + f.append(i) + return f ckpt_attn = True qkv_format = "bshd" @@ -951,21 +1117,22 @@ def test_te_layer_mqa_gqa(dtype, model_configs, model): num_querys_per_gqa_group = find_factors(config.num_heads) for num_q_per_gqa_group in num_querys_per_gqa_group: - config.num_gqa_groups=config.num_heads // num_q_per_gqa_group - test_transformer_layer(dtype, model_configs, model, - ckpt_attn, qkv_format, fused_qkv_params, RoPE) + config.num_gqa_groups = config.num_heads // num_q_per_gqa_group + test_transformer_layer( + dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE + ) def _run_transformer_layer( - dtype: torch.dtype, - config: ModelConfig, - backend: str, - ckpt_attn: bool, - qkv_format: str, - workspace_opt: bool, - fused_qkv_params: bool, - RoPE: bool, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + dtype: torch.dtype, + config: ModelConfig, + backend: str, + ckpt_attn: bool, + qkv_format: str, + workspace_opt: bool, + fused_qkv_params: bool, + RoPE: bool, +) -> Tuple[torch.Tensor, 
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Run TransformerLayer module with one forward pass and one backward pass""" # Set RNG and environment variables @@ -978,29 +1145,47 @@ def _run_transformer_layer( os.environ["NVTE_FUSED_ATTN"] = "1" # Create input tensor - inp = torch.randn(config.max_seqlen_q, config.batch_size, config.hidden_size, - dtype=dtype, device="cuda", requires_grad = True) + inp = torch.randn( + config.max_seqlen_q, + config.batch_size, + config.hidden_size, + dtype=dtype, + device="cuda", + requires_grad=True, + ) # In case the format to be tested is batch-first, need to transpose the # input tensor. if qkv_format == "bshd": - inp = inp.transpose(0,1) + inp = inp.transpose(0, 1) # Create seqlens if "padding" in config.attn_mask_type: - seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size], - dtype=torch.int32, device="cuda") + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) else: - seqlens_q = torch.full([config.batch_size], config.max_seqlen_q, - dtype=torch.int32, device="cuda") + seqlens_q = torch.full( + [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" + ) # Create attention mask if padding attention_mask = None if "padding" in config.attn_mask_type: attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) for i in range(config.batch_size): - attention_mask_q = torch.cat([attention_mask_q, - torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i])) - .to(torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0) + attention_mask_q = torch.cat( + [ + attention_mask_q, + torch.Tensor( + [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i]) + ) + .to(torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) attention_mask = attention_mask_q.to(device="cuda") sigma = 0.02 @@ -1009,14 +1194,19 @@ def _run_transformer_layer( layer_number = 1 drop_path_rate = 0.0 - drop_path_rates = [ - rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)] + drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)] # Create bias bias = None - if config.attn_bias_type == 'post_scale_bias': - bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv, - dtype=dtype, device="cuda") + if config.attn_bias_type == "post_scale_bias": + bias = torch.randn( + 1, + config.num_heads, + config.max_seqlen_q, + config.max_seqlen_kv, + dtype=dtype, + device="cuda", + ) # Create RoPE rotary_pos_emb = None @@ -1025,58 +1215,56 @@ def _run_transformer_layer( rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda") # Set up model - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_heads, - num_gqa_groups=config.num_gqa_groups, - layernorm_epsilon=1e-5, - hidden_dropout=0.0, - attention_dropout=config.dropout_p, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - kv_channels=config.head_dim, - self_attn_mask_type=config.attn_mask_type, - tp_group=None, - tp_size=1, - params_dtype=dtype, - get_rng_state_tracker=None, - fuse_wgrad_accumulation=False, - seq_length=config.max_seqlen_q, - micro_batch_size=config.batch_size, - sequence_parallel=False, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - layer_type="encoder", - drop_path_rate=drop_path_rates[layer_number - 1], - set_parallel_mode=True, - 
fuse_qkv_params=fused_qkv_params, - zero_centered_gamma=False, - qkv_weight_interleaved=False, - ub_tp_comm_overlap=False, - bias=True, - attn_input_format=qkv_format, - ) - .to(dtype=dtype, device="cuda") - ) + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_heads, + num_gqa_groups=config.num_gqa_groups, + layernorm_epsilon=1e-5, + hidden_dropout=0.0, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + kv_channels=config.head_dim, + self_attn_mask_type=config.attn_mask_type, + tp_group=None, + tp_size=1, + params_dtype=dtype, + get_rng_state_tracker=None, + fuse_wgrad_accumulation=False, + seq_length=config.max_seqlen_q, + micro_batch_size=config.batch_size, + sequence_parallel=False, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + layer_type="encoder", + drop_path_rate=drop_path_rates[layer_number - 1], + set_parallel_mode=True, + fuse_qkv_params=fused_qkv_params, + zero_centered_gamma=False, + qkv_weight_interleaved=False, + ub_tp_comm_overlap=False, + bias=True, + attn_input_format=qkv_format, + ).to(dtype=dtype, device="cuda") # Create ALiBi slopes alibi_slopes = None if config.attn_bias_type == "alibi" and config.alibi_type == "custom": - alibi_slopes = torch.randn( - config.num_heads).abs().to(dtype=torch.float32, device="cuda") + alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda") # Run a forward and backward pass - out = block(inp, + out = block( + inp, attention_mask=attention_mask, self_attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, rotary_pos_emb=rotary_pos_emb, core_attention_bias_type=config.attn_bias_type, core_attention_bias=bias, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + ) loss = out.sum() loss.backward() @@ -1085,23 +1273,24 @@ def _run_transformer_layer( model_configs_fp8_vs_f16 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_9" : ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_9": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), - "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), + "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), + "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] -qkv_layout_fp8_vs_f16 = ['sbh3d', 'bshd_bshd_bshd', 'sbhd_sbhd_sbhd'] -qkv_format_fp8_vs_f16 = ['bshd', 'sbhd'] +qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"] +qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] + def _rmse(a, b): - return math.sqrt((torch.pow((a-b), 2)/a.numel()).sum()) + return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum()) -@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.") +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) 
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) @@ -1118,58 +1307,78 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd): logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm) + dtype, config, True, qkv_format, input_layernorm + ) logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( - dtype, config, False, qkv_format, input_layernorm) + dtype, config, False, qkv_format, input_layernorm + ) tols = dict(atol=5e-1, rtol=5e-1) rmse_tol = 0.1 fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16) - fwd_range = max(fused_attn_fwd_fp8.max().item(), - fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(), - fused_attn_fwd_f16.min().item()) - - logging.debug('========== {:^25s} =========='.format('forward output')) - logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format( - fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item())) - logging.debug('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format( - fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item())) - logging.debug('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse)) + fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min( + fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item() + ) + + logging.debug("========== {:^25s} ==========".format("forward output")) + logging.debug( + "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format( + fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item() + ) + ) + logging.debug( + "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format( + fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item() + ) + ) + logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse)) try: torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols) except Exception as e: logging.debug(e) - assert(fwd_rmse < rmse_tol * fwd_range - ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range) + assert ( + fwd_rmse < rmse_tol * fwd_range + ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range + ) for i in range(len(param_names[:1])): bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i]) - bwd_range = max(fused_attn_bwd_fp8[i].max().item(), - fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(), - fused_attn_bwd_f16[i].min().item()) - - logging.debug('========== {:^25s} =========='.format(param_names[i])) - logging.debug('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i, - fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item())) - logging.debug('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i, - fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item())) - logging.debug('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse)) + bwd_range = max( + fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item() + ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item()) + + logging.debug("========== {:^25s} ==========".format(param_names[i])) + logging.debug( + "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format( + i, 
fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item() + ) + ) + logging.debug( + "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format( + i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item() + ) + ) + logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse)) try: torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols) except Exception as e: logging.debug(e) - assert(bwd_rmse < rmse_tol * bwd_range - ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range) + assert ( + bwd_rmse < rmse_tol * bwd_range + ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range + ) + def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER @@ -1184,7 +1393,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): ) with fp8_model_init(enabled=fp8_mha): - mha = (MultiheadAttention( + mha = MultiheadAttention( hidden_size=config.hidden_size, num_attention_heads=config.num_heads, kv_channels=config.head_dim, @@ -1199,34 +1408,35 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): attention_type="self", qkv_weight_interleaved=True, qkv_format=qkv_format, - ).to(dtype=dtype, device="cuda") - ) + ).to(dtype=dtype, device="cuda") - seqlens_q = torch.full([config.batch_size], config.max_seqlen_q, - dtype=torch.int32, device="cuda") - seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv, - dtype=torch.int32, device="cuda") + seqlens_q = torch.full( + [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.full( + [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" + ) cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0) dim_to_num = { - 'b' : config.batch_size, - 'sq' : config.max_seqlen_q, - 'skv': config.max_seqlen_kv, - 'h' : config.num_heads, - 'hg' : config.num_gqa_groups, - 'd' : config.head_dim, - 't' : cu_seqlens_q[-1], - 'tg' : cu_seqlens_kv[-1], - '3' : 3, - '2' : 2, - '1' : 1, - } - layout = '_'.join(qkv_format) - layout = layout.replace('s', 'sq') - tensor_shape = [dim_to_num[j] for j in layout.split('_')] + "b": config.batch_size, + "sq": config.max_seqlen_q, + "skv": config.max_seqlen_kv, + "h": config.num_heads, + "hg": config.num_gqa_groups, + "d": config.head_dim, + "t": cu_seqlens_q[-1], + "tg": cu_seqlens_kv[-1], + "3": 3, + "2": 2, + "1": 1, + } + layout = "_".join(qkv_format) + layout = layout.replace("s", "sq") + tensor_shape = [dim_to_num[j] for j in layout.split("_")] tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda") hidden_states = tensor.view(*tensor.shape[:-2], -1) hidden_states.requires_grad = True @@ -1234,27 +1444,28 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): out_grad = tensor.view(*tensor.shape[:-2], -1) with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe): - out = 
mha(hidden_states, + out = mha( + hidden_states, attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, is_first_microbatch=None, - ) + ) out.backward(out_grad) param_names = [] - param_names.append('hidden_states.grad') + param_names.append("hidden_states.grad") params = [] params.append(hidden_states) for name, param in mha.named_parameters(): if param.requires_grad: - param_names.append(name+'.grad') + param_names.append(name + ".grad") params.append(param) return out, param_names, tuple(x.grad for x in params) -@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.") +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) @@ -1264,62 +1475,75 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd): config = model_configs_fp8_vs_f16[model] - if (config.num_heads != config.num_gqa_groups and '3' in qkv_layout): - pytest.skip("qkv_layout not applicable for MQA/GQA"); + if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: + pytest.skip("qkv_layout not applicable for MQA/GQA") os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") - fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout) + fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout) logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") - fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - dtype, config, False, qkv_layout) + fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(dtype, config, False, qkv_layout) tols = dict(atol=5e-1, rtol=5e-2) rmse_tol = 0.1 - bwd_names = ['dq', 'dk', 'dv'] + bwd_names = ["dq", "dk", "dv"] fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16) - fwd_range = max(fused_attn_fwd_fp8.max().item(), - fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(), - fused_attn_fwd_f16.min().item()) - - logging.debug('========== {:^25s} =========='.format('forward output')) - logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format( - fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item())) - logging.debug('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format( - fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item())) - logging.debug('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse)) + fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min( + fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item() + ) + + logging.debug("========== {:^25s} ==========".format("forward output")) + logging.debug( + "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format( + fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item() + ) + ) + logging.debug( + "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format( + fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item() + ) + ) + logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse)) try: torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols) except Exception as e: logging.debug(e) - assert(fwd_rmse < rmse_tol * fwd_range - ), "FWD RMSE {:.5f} is over 
tolerance {:.5f} ({:.5f} * {:.5f})".format( - fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range) - for i,_ in enumerate(fused_attn_bwd_f16): + assert ( + fwd_rmse < rmse_tol * fwd_range + ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range + ) + for i, _ in enumerate(fused_attn_bwd_f16): bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i]) - bwd_range = max(fused_attn_bwd_fp8[i].max().item(), - fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(), - fused_attn_bwd_f16[i].min().item()) - - logging.debug('========== {:^25s} =========='.format(bwd_names[i])) - logging.debug('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i, - fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item())) - logging.debug('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i, - fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item())) - logging.debug('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse)) + bwd_range = max( + fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item() + ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item()) + + logging.debug("========== {:^25s} ==========".format(bwd_names[i])) + logging.debug( + "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format( + i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item() + ) + ) + logging.debug( + "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format( + i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item() + ) + ) + logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse)) try: torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols) except Exception as e: logging.debug(e) - assert(bwd_rmse < rmse_tol * bwd_range - ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range) + assert ( + bwd_rmse < rmse_tol * bwd_range + ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range + ) def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): @@ -1327,6 +1551,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER @@ -1339,60 +1564,60 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): fp8_dpa=fp8_dpa, ) - qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) with fp8_model_init(enabled=fp8_dpa): - dpa = ( - DotProductAttention( - config.num_heads, - config.head_dim, - num_gqa_groups=config.num_gqa_groups, - attention_dropout=config.dropout_p, - sequence_parallel=False, - tp_size=1, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, - tp_group=None, - layer_number=1, - attention_type="self", - qkv_format=qkv_format, - ).to(dtype=dtype, device="cuda") - ) + dpa = DotProductAttention( + config.num_heads, + config.head_dim, + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + sequence_parallel=False, + tp_size=1, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + tp_group=None, + layer_number=1, + attention_type="self", + qkv_format=qkv_format, + ).to(dtype=dtype, 
device="cuda") - seqlens_q = torch.full([config.batch_size], config.max_seqlen_q, - dtype=torch.int32, device="cuda") - seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv, - dtype=torch.int32, device="cuda") + seqlens_q = torch.full( + [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.full( + [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" + ) cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0) dim_to_num = { - 'b' : config.batch_size, - 'sq' : config.max_seqlen_q, - 'skv': config.max_seqlen_kv, - 'h' : config.num_heads, - 'hg' : config.num_gqa_groups, - 'd' : config.head_dim, - 't' : cu_seqlens_q[-1], - 'tg' : cu_seqlens_kv[-1], - '3' : 3, - '2' : 2, - '1' : 1, - } + "b": config.batch_size, + "sq": config.max_seqlen_q, + "skv": config.max_seqlen_kv, + "h": config.num_heads, + "hg": config.num_gqa_groups, + "d": config.head_dim, + "t": cu_seqlens_q[-1], + "tg": cu_seqlens_kv[-1], + "3": 3, + "2": 2, + "1": 1, + } inp = [] - for i,layout in enumerate(qkv_layout.split('_')): - layout = '_'.join(layout) + for i, layout in enumerate(qkv_layout.split("_")): + layout = "_".join(layout) if i == 0: - layout = layout.replace('s', 'sq') + layout = layout.replace("s", "sq") else: - layout = layout.replace('s', 'skv') - layout = layout.replace('h', 'hg') - layout = layout.replace('t', 'tg') - tensor_shape = [dim_to_num[j] for j in layout.split('_')] + layout = layout.replace("s", "skv") + layout = layout.replace("h", "hg") + layout = layout.replace("t", "tg") + tensor_shape = [dim_to_num[j] for j in layout.split("_")] tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") tensor_count = 1 split_dim = 0 - for dim, l in enumerate(layout.split('_')): + for dim, l in enumerate(layout.split("_")): if l.isdigit(): tensor_count = int(l) split_dim = dim @@ -1406,14 +1631,17 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): for i in range(3): inp[i].requires_grad = True - qkv_format_kv = '_'.join(qkv_format) - qkv_format_kv = qkv_format_kv.replace('s', 'sq') - out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')] + qkv_format_kv = "_".join(qkv_format) + qkv_format_kv = qkv_format_kv.replace("s", "sq") + out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda") with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe): - out = dpa(inp[0], inp[1], inp[2], + out = dpa( + inp[0], + inp[1], + inp[2], qkv_format=qkv_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, @@ -1423,7 +1651,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, is_first_microbatch=True, - ) + ) out.backward(out_grad) return out, (inp[0].grad, inp[1].grad, inp[2].grad) @@ -1431,22 +1659,22 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): model_configs_fp8 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"), - "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), - "fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", 
"no_bias"), + "fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"), + "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), + "fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"), - "fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"), - "fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"), + "fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"), + "fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), } param_types_fp8 = [torch.float16, torch.bfloat16] -cudnn_frontend_version = int(os.getenv('NVTE_FUSED_ATTN_FE_VER','1')) -models_v0 = ['fp8_1', 'fp8_2', 'fp8_5', 'fp8_6'] -models_v1 = ['fp8_3', 'fp8_4', 'fp8_7', 'fp8_8'] +cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1")) +models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"] +models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"] -@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.") +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") @pytest.mark.parametrize("dtype", param_types_fp8) @@ -1460,50 +1688,62 @@ def test_custom_mha_fp8_vs_f16(dtype, model): config = model_configs_fp8[model] - fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8( - dtype, config, "FusedAttention") - unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16( - dtype, config, "UnfusedAttention") + fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") + unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") tols = dict(atol=5e-1, rtol=5e-1) rmse_tol = 0.1 fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16) - fwd_range = max(fused_attn_fwd_fp8.max().item(), - unfused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(), - unfused_attn_fwd_f16.min().item()) + fwd_range = max(fused_attn_fwd_fp8.max().item(), unfused_attn_fwd_f16.max().item()) - min( + fused_attn_fwd_fp8.min().item(), unfused_attn_fwd_f16.min().item() + ) bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16) - bwd_range = max(fused_attn_bwd_fp8.max().item(), - unfused_attn_bwd_f16.max().item()) - min(fused_attn_bwd_fp8.min().item(), - unfused_attn_bwd_f16.min().item()) - - logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format( - fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item())) - logging.debug('unfused_attn_fwd_f16 min {:.6f} max {:.6f}'.format( - unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item())) - logging.debug('fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}'.format( - fwd_rmse)) + bwd_range = max(fused_attn_bwd_fp8.max().item(), unfused_attn_bwd_f16.max().item()) - min( + fused_attn_bwd_fp8.min().item(), unfused_attn_bwd_f16.min().item() + ) + + logging.debug( + "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format( + fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item() + ) + 
)
+    logging.debug(
+        "unfused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
+            unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()
+        )
+    )
+    logging.debug("fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}".format(fwd_rmse))
     try:
         torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols)
     except Exception as e:
         logging.debug(e)

-    logging.debug('fused_attn_bwd_fp8 min {:.6f} max {:.6f}'.format(
-        fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()))
-    logging.debug('unfused_attn_bwd_f16 min {:.6f} max {:.6f}'.format(
-        unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()))
-    logging.debug('fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}'.format(
-        bwd_rmse))
+    logging.debug(
+        "fused_attn_bwd_fp8 min {:.6f} max {:.6f}".format(
+            fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()
+        )
+    )
+    logging.debug(
+        "unfused_attn_bwd_f16 min {:.6f} max {:.6f}".format(
+            unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()
+        )
+    )
+    logging.debug("fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}".format(bwd_rmse))
     try:
         torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols)
     except Exception as e:
         logging.debug(e)

-    assert(fwd_rmse < rmse_tol * fwd_range
-        ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-        fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
-    assert(bwd_rmse < rmse_tol * bwd_range
-        ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-        bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
+    assert (
+        fwd_rmse < rmse_tol * fwd_range
+    ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+        fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
+    )
+    assert (
+        bwd_rmse < rmse_tol * bwd_range
+    ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+        bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
+    )


 def _run_custom_mha_fp8(dtype, config, backend):
@@ -1517,18 +1757,25 @@ def _run_custom_mha_fp8(dtype, config, backend):
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"

-    inp = 0.0001 * torch.randint(-100, 100,
-        (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
-        dtype=dtype, device="cuda", requires_grad=True)
-    seqlens = torch.full([config.batch_size], config.max_seqlen_q,
-        dtype=torch.int32, device="cuda")
+    inp = 0.0001 * torch.randint(
+        -100,
+        100,
+        (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
+        dtype=dtype,
+        device="cuda",
+        requires_grad=True,
+    )
+    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
     cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
     cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
     out_grad = 0.01 * torch.randn(
-        config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim,
-        dtype=dtype, device="cuda")
-    torch.save(out_grad, 'out_grad.pt')
+        config.batch_size * config.max_seqlen_q,
+        config.num_heads * config.head_dim,
+        dtype=dtype,
+        device="cuda",
+    )
+    torch.save(out_grad, "out_grad.pt")

     fp8_recipe = recipe.DelayedScaling(
         margin=0,
@@ -1543,10 +1790,13 @@ def _run_custom_mha_fp8(dtype, config, backend):
     out.backward(out_grad)

     out = torch.load("out.pt")
-    dqkv = torch.load('dqkv.pt')
-    return (out.view(config.batch_size, config.max_seqlen_q, -1),
-        dqkv.view(config.batch_size, config.max_seqlen_q, 3,
-        config.num_heads, config.head_dim).contiguous())
+    dqkv = torch.load("dqkv.pt")
+ return ( + out.view(config.batch_size, config.max_seqlen_q, -1), + dqkv.view( + config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim + ).contiguous(), + ) def _run_ref_mha_f16(dtype, config, backend): @@ -1560,13 +1810,14 @@ def _run_ref_mha_f16(dtype, config, backend): if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - inp = torch.load('qkv.pt').to(device="cuda") + inp = torch.load("qkv.pt").to(device="cuda") inp.requires_grad = True seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda") cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32) cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) - out_grad = torch.load('out_grad.pt').to(device="cuda").view( - config.batch_size, config.max_seqlen_q, -1) + out_grad = ( + torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1) + ) _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -1582,24 +1833,22 @@ def _run_ref_mha_f16(dtype, config, backend): """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - block = ( - DotProductAttention( - config.num_heads, - config.head_dim, - attention_dropout=config.dropout_p, - sequence_parallel=False, - tp_size=1, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, - tp_group=None, - layer_number=1, - attention_type="self", - qkv_format="bshd", - ).to(dtype=dtype, device="cuda") - ) - - q = inp[:,:,0,:,:] - k = inp[:,:,1,:,:] - v = inp[:,:,2,:,:] + block = DotProductAttention( + config.num_heads, + config.head_dim, + attention_dropout=config.dropout_p, + sequence_parallel=False, + tp_size=1, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + tp_group=None, + layer_number=1, + attention_type="self", + qkv_format="bshd", + ).to(dtype=dtype, device="cuda") + + q = inp[:, :, 0, :, :] + k = inp[:, :, 1, :, :] + v = inp[:, :, 2, :, :] out = block(q, k, v, attn_mask_type=config.attn_mask_type) out.backward(out_grad) @@ -1611,12 +1860,12 @@ _2X_ACC_FPROP = False _2X_ACC_DGRAD = False _2X_ACC_WGRAD = False -META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT +META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 -META_O = tex.FP8FwdTensors.GEMM2_INPUT -META_DO = tex.FP8BwdTensors.GRAD_INPUT2 -META_S = tex.FP8FwdTensors.GEMM3_OUTPUT -META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +META_O = tex.FP8FwdTensors.GEMM2_INPUT +META_DO = tex.FP8BwdTensors.GRAD_INPUT2 +META_S = tex.FP8FwdTensors.GEMM3_OUTPUT +META_DP = tex.FP8BwdTensors.GRAD_INPUT3 class _custom_mha_fp8(torch.autograd.Function): @@ -1682,46 +1931,57 @@ class _custom_mha_fp8(torch.autograd.Function): D_dtype=fp8_dtype_forward, ) qkv = qkv.view(-1, 3, h, d) - qkv_fp16 = ext.cast_from_fp8(qkv, fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, - tex.DType.kFloat16).view(b, max_s, 3, h, d).contiguous() - torch.save(qkv_fp16, 'qkv.pt') + qkv_fp16 = ( + ext.cast_from_fp8( + qkv, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, tex.DType.kFloat16 + ) + .view(b, max_s, 3, h, d) + .contiguous() + ) + torch.save(qkv_fp16, "qkv.pt") if cudnn_frontend_version == 1: - qkv = qkv.view(b, max_s, 3, h, d) # bs3hd + qkv = qkv.view(b, max_s, 3, h, d) # bs3hd # FMHA out, aux_ctx_tensors, *rest = fused_attn_fwd( - is_training, - max_s, - max_s, - cu_seqlens, - cu_seqlens, - qkv[:,:,0,:,:] if cudnn_frontend_version == 1 else qkv[:,0,:,:], - qkv[:,:,1,:,:] if cudnn_frontend_version == 1 else qkv[:,1,:,:], - qkv[:,:,2,:,:] if 
cudnn_frontend_version == 1 else qkv[:,2,:,:], - fp8_dtype_forward, - FusedAttnBackend["FP8"], - None, None, None, None, None, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], - attn_scale=None, - dropout=p_dropout, - fast_zero_fill=fast_zero_fill, - qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd", - attn_bias_type="no_bias", - attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding", - rng_gen=None, - ) + is_training, + max_s, + max_s, + cu_seqlens, + cu_seqlens, + qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :], + qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :], + qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :], + fp8_dtype_forward, + FusedAttnBackend["FP8"], + None, + None, + None, + None, + None, + fp8_meta["scaling_fwd"].scale_inv[META_QKV], + fp8_meta["scaling_fwd"].scale_inv[META_S], + fp8_meta["scaling_fwd"].scale[META_S], + fp8_meta["scaling_fwd"].scale[META_O], + fp8_meta["scaling_fwd"].amax_history[0][META_S], + fp8_meta["scaling_fwd"].amax_history[0][META_O], + attn_scale=None, + dropout=p_dropout, + fast_zero_fill=fast_zero_fill, + qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd", + attn_bias_type="no_bias", + attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding", + rng_gen=None, + ) M, ZInv, philox_unpacked = aux_ctx_tensors ctx.save_for_backward( - inp_t_fp8, qkv_weight_t_fp8, workspace, - qkv, out, + inp_t_fp8, + qkv_weight_t_fp8, + workspace, + qkv, + out, fp8_meta["scaling_fwd"].scale, fp8_meta["scaling_fwd"].scale_inv, ) @@ -1736,82 +1996,84 @@ class _custom_mha_fp8(torch.autograd.Function): ctx.mask_type = mask_type ctx.dtype = inp.dtype - out = out.view(-1, in_features) # (bs)(hd) - out_fp16 = ext.cast_from_fp8(out, fp8_meta["scaling_fwd"], - META_O, fp8_dtype_forward, tex.DType.kFloat16) - torch.save(out_fp16, 'out.pt') # (bs)(hd) + out = out.view(-1, in_features) # (bs)(hd) + out_fp16 = ext.cast_from_fp8( + out, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward, tex.DType.kFloat16 + ) + torch.save(out_fp16, "out.pt") # (bs)(hd) return out_fp16 - @staticmethod - def backward( - ctx, grad_output: torch.Tensor - ) -> Tuple[Union[torch.Tensor, None], ...]: + def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: with torch.cuda.nvtx.range("_DPA"): ( inp_t_fp8, qkv_weight_t_fp8, workspace, - qkv, out, + qkv, + out, fwd_scales, fwd_scale_inverses, ) = ctx.saved_tensors - fp8_dtype_forward = fp8.get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=True - ) - fp8_dtype_backward = fp8.get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) + fp8_dtype_forward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) proj_dgrad = ext.cast_to_fp8( grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ) # (bs)(hd) + ) # (bs)(hd) dq, dk, dv, *rest = fused_attn_bwd( - ctx.max_s, - ctx.max_s, - ctx.cu_seqlens, - ctx.cu_seqlens, - qkv[:,:,0,:,:] if cudnn_frontend_version == 1 else qkv[:,0,:,:], - qkv[:,:,1,:,:] if cudnn_frontend_version == 1 else qkv[:,1,:,:], - qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:], - out, - proj_dgrad.view_as(out), - 
fp8_dtype_forward, - fp8_dtype_backward, - ctx.aux_ctx_tensors, - FusedAttnBackend["FP8"], - None, None, None, None, - fwd_scale_inverses[META_QKV], # d_scale_qkv, - fwd_scale_inverses[META_S], # d_scale_s, - fwd_scale_inverses[META_O], # d_scale_o, - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp - ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv - attn_scale=None, - dropout=ctx.p_dropout, - fast_zero_fill=ctx.fast_zero_fill, - qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd", - attn_bias_type="no_bias", - attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding", - ) + ctx.max_s, + ctx.max_s, + ctx.cu_seqlens, + ctx.cu_seqlens, + qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :], + qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :], + qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :], + out, + proj_dgrad.view_as(out), + fp8_dtype_forward, + fp8_dtype_backward, + ctx.aux_ctx_tensors, + FusedAttnBackend["FP8"], + None, + None, + None, + None, + fwd_scale_inverses[META_QKV], # d_scale_qkv, + fwd_scale_inverses[META_S], # d_scale_s, + fwd_scale_inverses[META_O], # d_scale_o, + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp + fwd_scales[META_S], # q_scale_s + ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp + ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv + attn_scale=None, + dropout=ctx.p_dropout, + fast_zero_fill=ctx.fast_zero_fill, + qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd", + attn_bias_type="no_bias", + attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding", + ) dim = 2 if cudnn_frontend_version == 1 else 1 dqkv = torch.Tensor().to(device=dq.device, dtype=dq.dtype) dqkv_shape = list(dq.shape) dqkv_shape.insert(dim, 3) dqkv_stride = list(dq.stride()) - dqkv_stride.insert(dim, int(dqkv_stride[-3]/3)) - dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride) # bs3hd + dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3)) + dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride) # bs3hd - dqkv_c = dqkv.view(-1, 3*ctx.hidden_size) - dqkv_c_fp16 = ext.cast_from_fp8(dqkv_c, - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, tex.DType.kFloat16) - torch.save(dqkv_c_fp16, 'dqkv.pt') + dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size) + dqkv_c_fp16 = ext.cast_from_fp8( + dqkv_c, + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + tex.DType.kFloat16, + ) + torch.save(dqkv_c_fp16, "dqkv.pt") qkv_bgrad, dqkv_t = ext.fp8_transpose_bgrad_fused( dqkv_c, @@ -1850,7 +2112,8 @@ class _custom_mha_fp8(torch.autograd.Function): use_split_accumulator=_2X_ACC_WGRAD, ) - return (qkv_dgrad, + return ( + qkv_dgrad, qkv_wgrad, qkv_bgrad, None, @@ -1862,14 +2125,12 @@ class _custom_mha_fp8(torch.autograd.Function): None, None, None, - None) + None, + ) class Custom_MHA_FP8(TransformerEngineBaseModule): - def __init__( - self, - config, - params_dtype: torch.dtype = torch.float32): + 
def __init__(self, config, params_dtype: torch.dtype = torch.float32): super().__init__() self.p_dropout = config.dropout_p self.h = config.num_heads @@ -1901,8 +2162,10 @@ class Custom_MHA_FP8(TransformerEngineBaseModule): ) def forward( - self, inp: torch.Tensor, - cu_seqlens, max_s, + self, + inp: torch.Tensor, + cu_seqlens, + max_s, ) -> torch.Tensor: with self.prepare_forward(inp, None, num_gemms=3) as inp: out = _custom_mha_fp8.apply( @@ -1917,5 +2180,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule): self.fp8_meta, self.workspace, self.training, - self.mask_type) + self.mask_type, + ) return out diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 928ca27578beb3b0c170ac49fe1c270972e92f12..754416c837029ce774df2b9071bfcfbb06a5cec7 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -14,12 +14,13 @@ from transformer_engine.pytorch.utils import get_device_compute_capability model_configs_flash_attn = { # test: b, h, hg, d, sq, skv, p, mask, bias - "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA - "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA - "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA + "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA + "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA + "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA } + def get_bash_arguments(**kwargs): args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"] te_path = os.getenv("TE_PATH", "/opt/transformerengine") @@ -29,46 +30,43 @@ def get_bash_arguments(**kwargs): args.append(f"{k}={v}") return args + @pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ['bf16', 'fp16']) +@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) -@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd']) +@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) def test_cp_with_flash_attention(dtype, model, qkv_format): subprocess.run( get_bash_arguments( - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend='FlashAttention' + dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention" ), - check=True + check=True, ) + model_configs_fused_attn = { # test: b, h, hg, d, sq, skv, p, mask, bias - "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA - "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA - "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA - "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA - "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA - "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA - 
"cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA + "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA + "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA + "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA + "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA + "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA + "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA + "cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA } -@pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.") + +@pytest.mark.skipif(_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ['bf16', 'fp16']) +@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) -@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd']) +@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) def test_cp_with_fused_attention(dtype, model, qkv_format): subprocess.run( get_bash_arguments( - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend='FusedAttention' + dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention" ), - check=True + check=True, ) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 71023c32f9ab042a76761c4c5615b20927d5e687..8d3a9dca4f228a64e3b653cd7233f3ff9aa59c23 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -8,8 +8,15 @@ import pytest import torch from transformer_engine.pytorch import ( - DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, make_graphed_callables, - MultiheadAttention, TransformerLayer, fp8_autocast, fp8_model_init, + DotProductAttention, + LayerNormLinear, + LayerNormMLP, + Linear, + make_graphed_callables, + MultiheadAttention, + TransformerLayer, + fp8_autocast, + fp8_model_init, ) from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import is_bf16_compatible @@ -26,15 +33,18 @@ torch.cuda.manual_seed(seed) _cpu_rng_state = torch.get_rng_state() _cuda_rng_state = torch.cuda.get_rng_state() + @dataclass class ModelConfig: """Data tensor dimensions within Transformer model""" + sequence_length: int batch_size: int hidden_size: int num_heads: int kv_channels: int + model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"] @@ -66,7 +76,9 @@ def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) for i, (t1, t2) in enumerate(zip(l1, l2)): if not torch.equal(t1, t2): failed = True - failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n" + failed_tensors += ( + f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n" + ) assert not failed, "Output mismatches in:\n" + failed_tensors @@ -157,41 +169,51 @@ def _test_cuda_graphs( with fp8_model_init(enabled=fp8_params): # Create modules. 
if module == "transformer": - modules = [TransformerLayer( - config.hidden_size, - config.hidden_size, - config.num_heads, - hidden_dropout=0.0, - attention_dropout=0.0, - fuse_qkv_params=True, - params_dtype=dtype, - ) for _ in range(num_layers)] + modules = [ + TransformerLayer( + config.hidden_size, + config.hidden_size, + config.num_heads, + hidden_dropout=0.0, + attention_dropout=0.0, + fuse_qkv_params=True, + params_dtype=dtype, + ) + for _ in range(num_layers) + ] elif module == "layernorm_mlp": - modules = [LayerNormMLP( - config.hidden_size, config.hidden_size, params_dtype=dtype - ) for _ in range(num_layers)] + modules = [ + LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype) + for _ in range(num_layers) + ] elif module == "layernorm_linear": - modules = [LayerNormLinear( - config.hidden_size, config.hidden_size, params_dtype=dtype - ) for _ in range(num_layers)] + modules = [ + LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype) + for _ in range(num_layers) + ] elif module == "mha": - modules = [MultiheadAttention( - config.hidden_size, - config.num_heads, - attention_dropout=0.0, - params_dtype=dtype, - fuse_qkv_params=True, - ) for _ in range(num_layers)] + modules = [ + MultiheadAttention( + config.hidden_size, + config.num_heads, + attention_dropout=0.0, + params_dtype=dtype, + fuse_qkv_params=True, + ) + for _ in range(num_layers) + ] elif dpa: assert config.hidden_size % config.num_heads == 0, "Err." assert num_layers == 1, "Err." - modules = [DotProductAttention( - config.num_heads, config.kv_channels, attention_dropout=0.0 - ) for _ in range(num_layers)] + modules = [ + DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0) + for _ in range(num_layers) + ] else: - modules = [Linear( - config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype - ) for _ in range(num_layers)] + modules = [ + Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype) + for _ in range(num_layers) + ] # Initialize gradient buffers. 
for module in modules: @@ -238,7 +260,7 @@ def _test_cuda_graphs( with fp8_autocast(enabled=fp8): kwargs = {} if fp8_weight_caching: - kwargs["is_first_microbatch"] = (grad_accumulation_step == 0) + kwargs["is_first_microbatch"] = grad_accumulation_step == 0 output = model(*inputs, **kwargs) output.backward(grad_output) if not dpa: diff --git a/tests/pytorch/test_deferred_init.py b/tests/pytorch/test_deferred_init.py index cbc761a27caa084a06b238ec11798d95969ec917..0469a01c5f109ddf6fd4a2cdc8bd93340418f709 100644 --- a/tests/pytorch/test_deferred_init.py +++ b/tests/pytorch/test_deferred_init.py @@ -27,33 +27,31 @@ num_heads = 16 head_dim = 64 dtype = torch.bfloat16 + class TestDeferredInit: @staticmethod def get_module_args(module): hidden_size = num_heads * head_dim args = (hidden_size,) - kwargs = { - 'params_dtype': dtype, - 'device': 'meta' - } + kwargs = {"params_dtype": dtype, "device": "meta"} if module in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]: ffn_hidden_size = 2 * hidden_size - args += (ffn_hidden_size, ) - kwargs['bias'] = True + args += (ffn_hidden_size,) + kwargs["bias"] = True if module == te.LayerNormMLP: - kwargs['seq_length'] = seq_length + kwargs["seq_length"] = seq_length elif module == te.MultiheadAttention: - args += (num_heads, ) - kwargs['fuse_qkv_params'] = True + args += (num_heads,) + kwargs["fuse_qkv_params"] = True elif module == te.TransformerLayer: args += (3 * hidden_size, num_heads) - kwargs['fuse_qkv_params'] = True - kwargs['seq_length'] = seq_length + kwargs["fuse_qkv_params"] = True + kwargs["seq_length"] = seq_length return args, kwargs - @pytest.mark.parametrize("module_type", _core_modules+_composed_modules) + @pytest.mark.parametrize("module_type", _core_modules + _composed_modules) def test_zero_memory_init( self, module_type: torch.nn.Module, diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index cead7767bb2b4d9d21aa05fe712635c1d43d4383..0ea03197714843e78dc03e4a48d64644fc408d36 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -26,6 +26,7 @@ _tols: Dict[tex.DType, Dict[str, float]] = { tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125 } + def _to_list(x: Union[Iterable, Any]) -> List: """Convert to list if iterable, otherwise put in singleton list""" if isinstance(x, Iterable): @@ -33,12 +34,14 @@ def _to_list(x: Union[Iterable, Any]) -> List: else: return [x] + # Types that can be interpreted as tensor dims DimsType = Union[Iterable[int], int] # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) class TestFloat8Tensor: @@ -108,7 +111,7 @@ class TestFloat8Tensor: def test_quantize_dequantize_scales(self, scale: float) -> None: self._test_quantize_dequantize(scale=scale) - @pytest.mark.parametrize("dims", [[], 1, 311, [7,11], [7,5,3], [2,3,5,3]]) + @pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]]) def test_quantize_dequantize_dims(self, dims: DimsType) -> None: self._test_quantize_dequantize(dims=dims) @@ -310,7 +313,7 @@ class TestFloat8Tensor: def test_serialization( self, - dims: DimsType = [2,3,5], + dims: DimsType = [2, 3, 5], fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, scale: float = 0.5, dtype: torch.dtype = torch.float32, diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 
63ef4bdc0b236ccf6126f35fb845e901ec471e09..8a50648391b3cfde2daad72bbbe19e6ccb480811 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -117,9 +117,7 @@ class TestFusedAdam(TestFusedOptimizer): tensors = [] for size in sizes: tensors.append(torch.rand(size, dtype=torch.float, device="cuda")) - ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim( - tensors, self.options - ) + ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(tensors, self.options) for i in range(self.iters): self.gen_grad(ref_param, tst_param) @@ -139,9 +137,7 @@ class TestFusedAdam(TestFusedOptimizer): } tensor = torch.rand(nelem, dtype=torch.float, device="cuda") - ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim( - [tensor], adam_option - ) + ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option) for i in range(self.iters): self.gen_grad(ref_param, tst_param) @@ -161,9 +157,7 @@ class TestFusedAdam(TestFusedOptimizer): } tensor = torch.rand(nelem, dtype=torch.float, device="cuda") - ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim( - [tensor], adam_option - ) + ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option) # Add an empty param group which may occur for pipeline parallel p-tuning tst_optim.add_param_group({"params": []}) @@ -175,10 +169,11 @@ class TestFusedAdam(TestFusedOptimizer): torch.testing.assert_close(ref_param, tst_param) + class TestFusedSGD(TestFusedOptimizer): def __init__(self, *args, **kwargs): super(TestFusedSGD, self).__init__(*args, **kwargs) - self.options = {"lr": .25, "momentum": .125} + self.options = {"lr": 0.25, "momentum": 0.125} self.ref_optim = torch.optim.SGD self.fused_optim = te.optimizers.FusedSGD @@ -188,7 +183,7 @@ class TestFusedSGD(TestFusedOptimizer): def test_half(self): self.gen_single_type_test(param_type=torch.float16) - @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required") + @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required") def test_multi_device(self): devices = ("cuda:0", "cuda:1") for current_dev, tensor_dev in product(devices, devices): @@ -452,8 +447,8 @@ class AdamTest(unittest.TestCase): @largeTensorTest("60GB", "cuda") def testLargeTensor(self): - t = torch.zeros(2359332864, dtype=torch.half, device='cuda') - t2 = torch.zeros(2359332864, dtype=torch.half, device='cuda') + t = torch.zeros(2359332864, dtype=torch.half, device="cuda") + t2 = torch.zeros(2359332864, dtype=torch.half, device="cuda") grad = torch.randn_like(t) t.grad = grad t2.grad = grad diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 0fb7597246239e6bbcd98db88d1db92b05a81709..d6ba66cbbcc20578018f7b6fb2323eaa54362613 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -26,10 +26,7 @@ def apply_rotary_pos_emb_thd( """ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return torch.cat( - [ - apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) - for x in torch.split(t, seqlens) - ] + [apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) for x in torch.split(t, seqlens)] ).squeeze(1) @@ -45,6 +42,7 @@ def get_tol(dtype: torch.dtype) -> Dict: def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: return output.sum() * 2 + # Gradient is a full tensor def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: t = torch.ones_like(output) @@ -86,9 +84,7 @@ def test_fused_rope( emb = 
rotary_pos_emb(seq_length) # unfused - output_unfused = apply_rotary_pos_emb( - t, emb, tensor_format=tensor_format, fused=False - ) + output_unfused = apply_rotary_pos_emb(t, emb, tensor_format=tensor_format, fused=False) loss_unfused = loss_func(output_unfused) loss_unfused.backward() grad_unfused = t.grad.detach().clone() diff --git a/tests/pytorch/test_gqa.py b/tests/pytorch/test_gqa.py index 91b29c6e266f675756605b8ee9d0c9d573de026d..9f9098891f8dc40924bc2ac938f0c1dbab4aa639 100644 --- a/tests/pytorch/test_gqa.py +++ b/tests/pytorch/test_gqa.py @@ -13,23 +13,16 @@ num_heads = 16 head_dim = 64 dtype = torch.bfloat16 num_attn_head = 16 -ffn_hidden_size=1024 +ffn_hidden_size = 1024 + @pytest.mark.parametrize("kv_channels", [128, 256]) @pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("num_gqa_groups", [1, 2, 4, 8, 16]) -def test_gqa( - kv_channels, - hidden_size, - num_gqa_groups -) -> None: - +def test_gqa(kv_channels, hidden_size, num_gqa_groups) -> None: + model = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_attn_head, - num_gqa_groups, - kv_channels=kv_channels + hidden_size, ffn_hidden_size, num_attn_head, num_gqa_groups, kv_channels=kv_channels ) # Run forward pass @@ -42,10 +35,9 @@ def test_gqa( assert model.self_attention.layernorm_qkv.query_weight.shape[0] == kv_channels * num_attn_head assert model.self_attention.layernorm_qkv.query_weight.shape[1] == hidden_size - + assert model.self_attention.layernorm_qkv.value_weight.shape[0] == kv_channels * num_gqa_groups assert model.self_attention.layernorm_qkv.value_weight.shape[1] == hidden_size - + assert model.self_attention.proj.weight.shape[0] == hidden_size assert model.self_attention.proj.weight.shape[1] == kv_channels * num_attn_head - diff --git a/tests/pytorch/test_jit.py b/tests/pytorch/test_jit.py index 24288b5ff485aab9a7a88ee837419e2e185f9308..7d69e0371272c4ee6bf64c17211bde428c580012 100644 --- a/tests/pytorch/test_jit.py +++ b/tests/pytorch/test_jit.py @@ -11,11 +11,11 @@ import transformer_engine.pytorch as te # Model names for test_torch_dynamo _model_factory = { - "Linear": [(lambda: te.Linear(16, 16)), [16, 16]], - "LayerNorm": [(lambda: te.LayerNorm(16)), [16, 16]], - "LayerNormLinear": [(lambda: te.LayerNormLinear(16, 16)), [16, 16]], - "LayerNormMLP": [(lambda: te.LayerNormMLP(16, 16)), [16, 16]], - "TransformerLayer": [(lambda: te.TransformerLayer(128, 128, 2)), [4, 1, 128]], + "Linear": [(lambda: te.Linear(16, 16)), [16, 16]], + "LayerNorm": [(lambda: te.LayerNorm(16)), [16, 16]], + "LayerNormLinear": [(lambda: te.LayerNormLinear(16, 16)), [16, 16]], + "LayerNormMLP": [(lambda: te.LayerNormMLP(16, 16)), [16, 16]], + "TransformerLayer": [(lambda: te.TransformerLayer(128, 128, 2)), [4, 1, 128]], } @@ -31,11 +31,11 @@ def test_torch_dynamo(model_name: str): # Helper function to construct tensor with default options def make_tensor( - dims: Tuple[int], - dtype: torch.dtype = torch.float32, - device: torch.device = "cuda", - requires_grad: bool = True, - **kwargs, + dims: Tuple[int], + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + requires_grad: bool = True, + **kwargs, ): return torch.zeros( dims, diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index 5b528b67274ee285348c78f579d1bb4902154007..216b200e092945c6a7534c4958bcbd26c3ca8697 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -28,9 +28,7 @@ appliers = [MultiTensorApply(2048 * 32), MultiTensorApply(333), 
MultiTensorApply @pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("out_type", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("inplace", [False, True]) -def test_multi_tensor_scale( - input_size_pair, applier, repeat, in_type, out_type, inplace -): +def test_multi_tensor_scale(input_size_pair, applier, repeat, in_type, out_type, inplace): if inplace is True and (out_type is not in_type): pytest.skip("inplace=True and out_type != in_type is not supported.") elif (in_type == torch.float16 and out_type == torch.bfloat16) or ( @@ -154,9 +152,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens in_list += [a.clone().to(in_type), b.clone().to(in_type)] if per_tensor: - norm, norm_per_tensor = applier( - tex.multi_tensor_l2norm, overflow_buf, [in_list], True - ) + norm, norm_per_tensor = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True) normab = torch.cat((a.norm().view(1), b.norm().view(1))) norm_per_tensor = norm_per_tensor.view(-1, 2) else: @@ -168,9 +164,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens torch.testing.assert_close(norm, reference.broadcast_to(norm.shape)) if per_tensor: - torch.testing.assert_close( - norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape) - ) + torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape)) assert overflow_buf.item() == 0 @@ -179,9 +173,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens @pytest.mark.parametrize("repeat", [1, 55]) @pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("per_tensor", [False, True]) -def test_multi_tensor_unscale_l2norm( - input_size_pair, applier, repeat, in_type, per_tensor -): +def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type, per_tensor): sizea, sizeb = input_size_pair device = torch.device("cuda") val = 4.0 @@ -205,9 +197,7 @@ def test_multi_tensor_unscale_l2norm( inv_scale_cuda, True, ) - normab = torch.cat( - ((a * inv_scale).norm().view(1), (b * inv_scale).norm().view(1)) - ) + normab = torch.cat(((a * inv_scale).norm().view(1), (b * inv_scale).norm().view(1))) norm_per_tensor = norm_per_tensor.view(-1, 2) else: norm, _ = applier( @@ -224,7 +214,5 @@ def test_multi_tensor_unscale_l2norm( torch.testing.assert_close(norm, reference.broadcast_to(norm.shape)) if per_tensor: - torch.testing.assert_close( - norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape) - ) + torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape)) assert overflow_buf.item() == 0 diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 2c9d61b845e3604ce3b3650662a8af03ba5e3cbc..4a3baf8823ca278493b287ad2ad27d9fd95db4b0 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -20,8 +20,15 @@ from transformer_engine.pytorch.utils import ( is_bf16_compatible, ) from transformer_engine.pytorch import ( - DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, - MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams + DotProductAttention, + LayerNormLinear, + LayerNormMLP, + Linear, + MultiheadAttention, + RMSNorm, + TransformerLayer, + LayerNorm, + InferenceParams, ) from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint @@ -106,10 +113,11 @@ def assert_allclose( if not result: diff = torch.abs(t1 - 
t2).flatten() m = torch.argmax(diff) - msg = (f"Outputs not close enough in tensor at idx={i}. " - f"Location of the maximum difference: {m.item()} " - f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} " - f"(diff {diff[m].item()})." + msg = ( + f"Outputs not close enough in tensor at idx={i}. " + f"Location of the maximum difference: {m.item()} " + f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} " + f"(diff {diff[m].item()})." ) raise AssertionError(msg) @@ -175,9 +183,7 @@ class TorchDotProductAttention(torch.nn.Module): ) # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.reshape( - output_size[2], output_size[0] * output_size[1], -1 - ) + query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) @@ -216,14 +222,10 @@ class TorchDotProductAttention(torch.nn.Module): ) # change view [sk, b * np, hn] - value_layer = value_layer.reshape( - value_layer.size(0), output_size[0] * output_size[1], -1 - ) + value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1) # change view [b * np, sq, sk] - attention_probs = attention_probs.view( - output_size[0] * output_size[1], output_size[2], -1 - ) + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) @@ -241,9 +243,7 @@ class TorchDotProductAttention(torch.nn.Module): class TorchLayerNorm(nn.Module): - def __init__(self, in_features: int, - eps: float, - zero_centered_gamma: bool): + def __init__(self, in_features: int, eps: float, zero_centered_gamma: bool): super().__init__() self.eps = eps self.in_features = in_features @@ -260,10 +260,12 @@ class TorchLayerNorm(nn.Module): w = w.to(torch.float32) b = self.bias.to(torch.float32) inp = x.to(torch.float32) - out = torch.nn.functional.layer_norm(inp, (self.in_features,), weight=w, - bias=b, eps=self.eps) + out = torch.nn.functional.layer_norm( + inp, (self.in_features,), weight=w, bias=b, eps=self.eps + ) return out.to(x.dtype) + # Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py class TorchRMSNorm(nn.Module): def __init__(self, in_features, zero_centered_gamma, eps=1e-5): @@ -278,11 +280,11 @@ class TorchRMSNorm(nn.Module): self.register_parameter("weight", self.weight) def forward(self, x): - norm_x2 = torch.sum(x.float()**2, dim=-1, keepdim=True) + norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True) d_x = self.in_features rms_x2 = norm_x2 / d_x + self.eps - r_rms_x = rms_x2 ** (-1. 
/ 2) + r_rms_x = rms_x2 ** (-1.0 / 2) x_normed = x * r_rms_x w = self.weight.float() @@ -292,17 +294,24 @@ class TorchRMSNorm(nn.Module): class TorchLayerNormLinear(nn.Module): - def __init__(self, in_features: int, out_features: int, - eps: float, bias: bool = True, - normalization: str = "LayerNorm", - zero_centered_gamma: bool = False): + def __init__( + self, + in_features: int, + out_features: int, + eps: float, + bias: bool = True, + normalization: str = "LayerNorm", + zero_centered_gamma: bool = False, + ): super().__init__() if normalization == "LayerNorm": - self.layernorm = TorchLayerNorm(in_features, eps=eps, - zero_centered_gamma=zero_centered_gamma) + self.layernorm = TorchLayerNorm( + in_features, eps=eps, zero_centered_gamma=zero_centered_gamma + ) elif normalization == "RMSNorm": - self.layernorm = TorchRMSNorm(in_features, eps=eps, - zero_centered_gamma=zero_centered_gamma) + self.layernorm = TorchRMSNorm( + in_features, eps=eps, zero_centered_gamma=zero_centered_gamma + ) else: raise RuntimeError("Unsupported normalization") @@ -329,21 +338,26 @@ class TorchMHA(nn.Module): output = output[0] return output + class TorchQuickGELU(nn.Module): def forward(self, input: torch.Tensor) -> torch.Tensor: return input * torch.sigmoid(1.702 * input) + class TorchSquaredRELU(nn.Module): def forward(self, input: torch.Tensor) -> torch.Tensor: return (input > 0) * input * input -_supported_act = {'geglu' : nn.GELU(approximate="tanh"), - 'gelu' : nn.GELU(approximate="tanh"), - 'reglu' : nn.ReLU(), - 'relu' : nn.ReLU(), - 'swiglu' : nn.SiLU(), - 'qgelu' : TorchQuickGELU(), - 'srelu' : TorchSquaredRELU()} + +_supported_act = { + "geglu": nn.GELU(approximate="tanh"), + "gelu": nn.GELU(approximate="tanh"), + "reglu": nn.ReLU(), + "relu": nn.ReLU(), + "swiglu": nn.SiLU(), + "qgelu": TorchQuickGELU(), + "srelu": TorchSquaredRELU(), +} class TorchGLU(nn.Module): @@ -353,26 +367,29 @@ class TorchGLU(nn.Module): def forward(self, x): shape = x.size(-1) - a = x[..., :shape // 2] - b = x[..., (shape // 2):] + a = x[..., : shape // 2] + b = x[..., (shape // 2) :] a = self.act(a) return a * b class TorchLayerNormMLP(nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, - eps: float = 1e-5, activation = 'gelu', - normalization: str = "LayerNorm"): + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + eps: float = 1e-5, + activation="gelu", + normalization: str = "LayerNorm", + ): super().__init__() if normalization == "LayerNorm": - self.ln = TorchLayerNorm(hidden_size, eps=eps, - zero_centered_gamma=False) + self.ln = TorchLayerNorm(hidden_size, eps=eps, zero_centered_gamma=False) elif normalization == "RMSNorm": - self.ln = TorchRMSNorm(hidden_size, eps=eps, - zero_centered_gamma=False) + self.ln = TorchRMSNorm(hidden_size, eps=eps, zero_centered_gamma=False) else: raise RuntimeError("Unsupported normalization") - if 'glu' in activation: + if "glu" in activation: fc1_output_features = 2 * ffn_hidden_size self.gelu = TorchGLU(activation) else: @@ -387,7 +404,9 @@ class TorchLayerNormMLP(nn.Module): class TorchGPT(nn.Module): - def __init__(self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool): + def __init__( + self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool + ): super().__init__() self.ln = nn.LayerNorm(hidden_size, eps=eps) self.causal_attn = TorchMHA(hidden_size, num_attention_heads) @@ -411,7 +430,6 @@ class TorchGPT(nn.Module): return x - def _test_e2e_selective_recompute(bs, 
dtype, config, fp8, fp8_model_params=False, recompute=False): reset_rng_states() FP8GlobalStateManager.reset() @@ -421,23 +439,21 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) with fp8_model_init(enabled=fp8 and fp8_model_params): - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.embed, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - params_dtype=dtype, - fuse_qkv_params=True, - device="cuda", - ) + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + params_dtype=dtype, + fuse_qkv_params=True, + device="cuda", ) te_inp_hidden_states = torch.randn( @@ -477,8 +493,12 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par config = model_configs[model] - outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False) - outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True) + outputs = _test_e2e_selective_recompute( + bs, dtype, config, fp8, fp8_model_params, recompute=False + ) + outputs_recompute = _test_e2e_selective_recompute( + bs, dtype, config, fp8, fp8_model_params, recompute=True + ) # Check that results match tols = dtype_tols(dtype) @@ -496,10 +516,7 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par def _test_e2e_full_recompute( - bs, dtype, config, fp8, - fp8_model_params=False, - recompute=False, - use_reentrant=True + bs, dtype, config, fp8, fp8_model_params=False, recompute=False, use_reentrant=True ): reset_rng_states() FP8GlobalStateManager.reset() @@ -586,10 +603,12 @@ def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, # Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0" - outputs, names = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, - recompute=False, use_reentrant=use_reentrant) - outputs_recompute, _ = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, - recompute=True, use_reentrant=use_reentrant) + outputs, names = _test_e2e_full_recompute( + bs, dtype, config, fp8, fp8_model_params, recompute=False, use_reentrant=use_reentrant + ) + outputs_recompute, _ = _test_e2e_full_recompute( + bs, dtype, config, fp8, fp8_model_params, recompute=True, use_reentrant=use_reentrant + ) if not use_reentrant: # Reset bias+GELU fusion flag to avoid contaminating other tests @@ -753,22 +772,19 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): config = model_configs[model] - te_gpt = ( - TransformerLayer( - hidden_size=config.hidden_size, - ffn_hidden_size=4 * config.hidden_size, - num_attention_heads=config.num_attention_heads, - layernorm_epsilon=config.eps, - attention_dropout=0.1, - hidden_dropout=0.1, - params_dtype=dtype, - fuse_qkv_params=True, 
- qkv_weight_interleaved=False, - parallel_attention_mlp=parallel_attention_mlp, - device="cuda", - ) - .eval() - ) + te_gpt = TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=4 * config.hidden_size, + num_attention_heads=config.num_attention_heads, + layernorm_epsilon=config.eps, + attention_dropout=0.1, + hidden_dropout=0.1, + params_dtype=dtype, + fuse_qkv_params=True, + qkv_weight_interleaved=False, + parallel_attention_mlp=parallel_attention_mlp, + device="cuda", + ).eval() torch_gpt = ( TorchGPT( @@ -853,18 +869,15 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): def test_mha_accuracy(dtype, bs, model, mask_type): config = model_configs[model] - te_mha = ( - MultiheadAttention( - config.hidden_size, - config.num_attention_heads, - fuse_qkv_params=True, - params_dtype=dtype, - qkv_weight_interleaved=False, - input_layernorm=False, - device="cuda", - ) - .eval() - ) + te_mha = MultiheadAttention( + config.hidden_size, + config.num_attention_heads, + fuse_qkv_params=True, + params_dtype=dtype, + qkv_weight_interleaved=False, + input_layernorm=False, + device="cuda", + ).eval() torch_mha = ( TorchMHA( @@ -919,7 +932,9 @@ def _test_granular_accuracy(block, bs, dtype, config): def _test_dpa_accuracy(block, bs, dtype, config): reset_rng_states() - mask = torch.triu(torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1) + mask = torch.triu( + torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1 + ) query, key, value = [ torch.randn( (config.seq_len, bs, config.num_attention_heads, config.embed), @@ -953,7 +968,7 @@ def test_dpa_accuracy(dtype, bs, model): DotProductAttention( config.num_attention_heads, config.embed, - attention_dropout=0.0, # disable dropout, FU uses rng differently + attention_dropout=0.0, # disable dropout, FU uses rng differently ) .to(dtype=dtype) .cuda() @@ -962,7 +977,7 @@ def test_dpa_accuracy(dtype, bs, model): torch_dpa = ( TorchDotProductAttention( config.embed, - 0.0, # dropout + 0.0, # dropout ) .to(dtype=dtype) .cuda() @@ -984,27 +999,21 @@ def test_dpa_accuracy(dtype, bs, model): def test_linear_accuracy(dtype, bs, model): config = model_configs[model] - te_linear = ( - Linear( - config.hidden_size, - 4 * config.hidden_size, - bias=True, - params_dtype=dtype, - device="cuda", - ) - .eval() - ) + te_linear = Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=True, + params_dtype=dtype, + device="cuda", + ).eval() - torch_linear = ( - torch.nn.Linear( - config.hidden_size, - 4 * config.hidden_size, - bias=True, - device="cuda", - dtype=dtype, - ) - .eval() - ) + torch_linear = torch.nn.Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=True, + device="cuda", + dtype=dtype, + ).eval() # Share params with torch.no_grad(): @@ -1029,23 +1038,16 @@ def test_linear_accuracy(dtype, bs, model): def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): config = model_configs[model] - te_rmsnorm = ( - RMSNorm( - config.hidden_size, - eps=eps, - params_dtype=dtype, - zero_centered_gamma=zero_centered_gamma, - device="cuda", - ) - .eval() - ) + te_rmsnorm = RMSNorm( + config.hidden_size, + eps=eps, + params_dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + device="cuda", + ).eval() torch_rmsnorm = ( - TorchRMSNorm( - config.hidden_size, - eps=eps, - zero_centered_gamma=zero_centered_gamma - ) + TorchRMSNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma) .to(dtype=dtype) .cuda() .eval() @@ 
-1059,12 +1061,14 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config) # Check output. - atol = {torch.float32 : 1e-7, - torch.half : 2e-3, - torch.bfloat16: 2e-2, + atol = { + torch.float32: 1e-7, + torch.half: 2e-3, + torch.bfloat16: 2e-2, } assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -1073,23 +1077,16 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): config = model_configs[model] - te_layernorm = ( - LayerNorm( - config.hidden_size, - eps=eps, - params_dtype=dtype, - zero_centered_gamma=zero_centered_gamma, - device="cuda", - ) - .eval() - ) + te_layernorm = LayerNorm( + config.hidden_size, + eps=eps, + params_dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + device="cuda", + ).eval() torch_layernorm = ( - TorchLayerNorm( - config.hidden_size, - eps=eps, - zero_centered_gamma=zero_centered_gamma - ) + TorchLayerNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma) .to(dtype=dtype) .cuda() .eval() @@ -1104,9 +1101,10 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config) # Check output. - atol = {torch.float32 : 1e-7, - torch.half : 2e-3, - torch.bfloat16: 2e-2, + atol = { + torch.float32: 1e-7, + torch.half: 2e-3, + torch.bfloat16: 2e-2, } assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) @@ -1119,19 +1117,16 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma): config = model_configs[model] - te_ln_linear = ( - LayerNormLinear( - config.hidden_size, - 4 * config.hidden_size, - config.eps, - bias=True, - normalization=normalization, - params_dtype=dtype, - zero_centered_gamma=zero_centered_gamma, - device="cuda", - ) - .eval() - ) + te_ln_linear = LayerNormLinear( + config.hidden_size, + 4 * config.hidden_size, + config.eps, + bias=True, + normalization=normalization, + params_dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + device="cuda", + ).eval() torch_ln_linear = ( TorchLayerNormLinear( @@ -1159,9 +1154,10 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config) # Check output. 
- atol = {torch.float32 : 2.5e-4, - torch.half : 2e-3, - torch.bfloat16: 2e-2, + atol = { + torch.float32: 2.5e-4, + torch.half: 2e-3, + torch.bfloat16: 2e-2, } assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) @@ -1174,17 +1170,14 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): config = model_configs[model] - te_ln_mlp = ( - LayerNormMLP( - config.hidden_size, - 4 * config.hidden_size, - activation=activation, - normalization=normalization, - params_dtype=dtype, - device="cuda", - ) - .eval() - ) + te_ln_mlp = LayerNormMLP( + config.hidden_size, + 4 * config.hidden_size, + activation=activation, + normalization=normalization, + params_dtype=dtype, + device="cuda", + ).eval() torch_ln_mlp = ( TorchLayerNormMLP( @@ -1226,8 +1219,10 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): optimizer = torch.optim.SGD(block.parameters(), lr=0.1) # Placeholders used for graph capture. - static_input = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True) - static_target = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype) + static_input = torch.randn( + config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True + ) + static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype) real_input = torch.rand_like(static_input) real_target = torch.rand_like(static_target) @@ -1286,22 +1281,20 @@ def test_gpt_cuda_graph(dtype, bs, model): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.embed, - params_dtype=dtype, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - device="cuda", - ) + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + device="cuda", ) graphed_block = copy.deepcopy(block) @@ -1388,7 +1381,6 @@ def test_gpt_fp8_parameters(dtype, bs, model): ) - @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -1451,7 +1443,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): requires_grad=True, ) - x_bshd = x_sbhd.transpose(0,1).contiguous() + x_bshd = x_sbhd.transpose(0, 1).contiguous() # To make sure forward is also identical (just in case some module decides # to act fancy) @@ -1466,7 +1458,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): # Check that results match torch.testing.assert_close( y_bshd, - y_sbhd.transpose(0,1).contiguous(), + y_sbhd.transpose(0, 1).contiguous(), ) @@ -1500,19 +1492,16 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, S_max = S + 2 if module == "TransformerLayer": - model = ( - TransformerLayer( - hidden_size=D, - 
ffn_hidden_size= 4 * D, - num_attention_heads=H, - attn_input_format=input_format, - layer_number=layer_number, - attention_dropout = 0.0, - params_dtype=dtype, - device="cuda", - ) - .eval() - ) + model = TransformerLayer( + hidden_size=D, + ffn_hidden_size=4 * D, + num_attention_heads=H, + attn_input_format=input_format, + layer_number=layer_number, + attention_dropout=0.0, + params_dtype=dtype, + device="cuda", + ).eval() else: model = ( MultiheadAttention( @@ -1520,7 +1509,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, num_attention_heads=H, qkv_format=input_format, layer_number=layer_number, - attention_dropout = 0.0, + attention_dropout=0.0, params_dtype=dtype, ) .cuda() @@ -1537,39 +1526,38 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, incremental_output = torch.zeros_like(input) # Generate output for the entire sequence - full_output = model( - hidden_states=input, - rotary_pos_emb=rotary_freqs if use_RoPE else None) + full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None) # Incrementaly generate outputs using KV-cache for i in range(S): if input_format == "sbhd": - incremental_input = input[i].view(1,B,D) + incremental_input = input[i].view(1, B, D) else: - incremental_input = input[:, i, :].view(B,1,D) + incremental_input = input[:, i, :].view(B, 1, D) line_output = model( hidden_states=incremental_input, inference_params=inference_params, - rotary_pos_emb=rotary_freqs if use_RoPE else None) + rotary_pos_emb=rotary_freqs if use_RoPE else None, + ) inference_params.sequence_len_offset += 1 if input_format == "sbhd": - incremental_output[i] = line_output.view(B,D) + incremental_output[i] = line_output.view(B, D) else: - incremental_output[:, i, :] = line_output.view(B,D) + incremental_output[:, i, :] = line_output.view(B, D) if module == "TransformerLayer": atol = { - torch.float32 : 5e-3, - torch.half : 5e-3, + torch.float32: 5e-3, + torch.half: 5e-3, torch.bfloat16: 5e-2, } else: atol = { - torch.float32 : 1e-3, - torch.half : 1e-3, + torch.float32: 1e-3, + torch.half: 1e-3, torch.bfloat16: 1e-2, } diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 1e9bbea507d8d86dec58a8f3e8cd325d53ee4edf..bdc459cdcc979fdb5daf4391283518aaf4a2e3ca 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -32,7 +32,13 @@ from typing import Optional, Union, Tuple, List import transformer_engine.pytorch as te from transformer_engine.common import recipe import transformer_engine_torch as tex -from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8 +from transformer_engine.pytorch.cpp_extensions import ( + gemm, + fp8_gemm, + gelu, + cast_to_fp8, + cast_from_fp8, +) from transformer_engine.pytorch.module.base import get_workspace import transformer_engine.pytorch.cpp_extensions as texcpp import transformer_engine.pytorch.softmax as softmax_defs @@ -50,8 +56,10 @@ if SAVE_TEST_IO: from polygraphy.comparator import RunResults # The directory where generated ONNX test models are stored. 
-NVTE_TEST_ARTIFACTS_DIR = os.environ.get('NVTE_TEST_ARTIFACTS_DIR') -NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(tempfile.gettempdir(), "./gen_onnx_models") +NVTE_TEST_ARTIFACTS_DIR = os.environ.get("NVTE_TEST_ARTIFACTS_DIR") +NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join( + tempfile.gettempdir(), "./gen_onnx_models" +) # The directory where this file is stored. @@ -100,23 +108,21 @@ def do_export( model: torch.nn.Module, inp: torch.Tensor, fname: str, - use_fp8: bool=True, - opset: int=OPSET, - input_names: List[str]=None, - output_names: List[str]=None, - dynamic_axes: List[str]=None + use_fp8: bool = True, + opset: int = OPSET, + input_names: List[str] = None, + output_names: List[str] = None, + dynamic_axes: List[str] = None, ): """Export to ONNX""" fp8_recipe = create_fp8_recipe() input_names = input_names or ["input"] output_names = output_names or ["output"] - with torch.inference_mode(), te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings(): - warnings.filterwarnings( - action='ignore', - category=torch.jit.TracerWarning, - module=r'.*' - ) + with torch.inference_mode(), te.fp8_autocast( + enabled=use_fp8, fp8_recipe=fp8_recipe + ), warnings.catch_warnings(): + warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*") model.cuda().eval() os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True) @@ -138,7 +144,8 @@ def do_export( input_names=input_names, output_names=output_names, do_constant_folding=True, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH) + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH, + ) def to_numpy(tensor): @@ -154,24 +161,30 @@ def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int): NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors. nb_total_scales = num_gemms * NB_SCALES_PER_GEMM module.init_fp8_metadata(num_gemms) - module.fp8_meta["scaling_fwd"].scale = torch.ones( - nb_total_scales, dtype=torch.float32, device="cuda") / scale - module.fp8_meta["scaling_fwd"].scale_inv = torch.ones( - nb_total_scales, dtype=torch.float32, device="cuda") * scale + module.fp8_meta["scaling_fwd"].scale = ( + torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") / scale + ) + module.fp8_meta["scaling_fwd"].scale_inv = ( + torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") * scale + ) def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool): """Transformer Engine forward propagation.""" fp8_recipe = create_fp8_recipe() - with torch.inference_mode(), te.fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings(): + with torch.inference_mode(), te.fp8_autocast( + enabled=is_fp8, fp8_recipe=fp8_recipe + ), warnings.catch_warnings(): te_outputs = model(*inps if isinstance(inps, tuple) else (inps,)) if not isinstance(te_outputs, tuple): te_outputs = (te_outputs,) return te_outputs -def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname): - """ Compare ORT and TE outputs.""" +def compare_outputs( + onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname +): + """Compare ORT and TE outputs.""" assert len(onnx_outputs) == len(te_outputs) # Compare ORT and PyTorch outputs. 
for onnx_output, te_output in zip(onnx_outputs, te_outputs): @@ -192,11 +205,15 @@ def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, al errors = abs_err[mismatches] for loc in mismatched_ids[:nb_vals]: ref = te_output[loc] - print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}") + print( + f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} >" + f" {atol + rtol * abs(ref)}" + ) print(f"Max error: {np.max(errors)}") if nb_errors > allow_cnt_errors: raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors") + def serialize_inputs_outputs( fname: str, inputs: Union[Tuple[torch.Tensor], torch.Tensor], @@ -214,10 +231,10 @@ def serialize_inputs_outputs( inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,) named_inputs = zip(input_names, inputs) input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}] - json_fname = fname[:-len(".onnx")] + "_inputs.json" + json_fname = fname[: -len(".onnx")] + "_inputs.json" save_json(input_data, json_fname, description="custom input data") - json_fname = fname[:-len(".onnx")] + "_output.json" + json_fname = fname[: -len(".onnx")] + "_output.json" named_outputs = zip(output_names, te_outputs) output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None} custom_outputs = RunResults() @@ -229,14 +246,14 @@ def validate_result( fname: str, inps: Union[Tuple[torch.Tensor], torch.Tensor], model: torch.nn.Module, - atol: float=1.e-8, # np.isclose default atol - rtol: float=1.e-5, # np.isclose default rtol - max_errors_printed: int=10, - is_fp8: bool=False, - allow_cnt_errors: int=0, - input_names: List[str]=None, - output_names: List[str]=None, - te_outputs: List[torch.Tensor]=None, + atol: float = 1.0e-8, # np.isclose default atol + rtol: float = 1.0e-5, # np.isclose default rtol + max_errors_printed: int = 10, + is_fp8: bool = False, + allow_cnt_errors: int = 0, + input_names: List[str] = None, + output_names: List[str] = None, + te_outputs: List[torch.Tensor] = None, ): """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX representation using ONNX Runtime (ORT) and ensure they are close. 
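# Editor's note (illustration, not part of the patch): validate_result() above exports a
# TE module, runs the ONNX graph with ONNX Runtime, and checks that the outputs are close.
# A condensed sketch of that flow with plain onnxruntime/numpy; "model.onnx" and the
# helper name are placeholders, and the closeness rule mirrors np.isclose.
import numpy as np
import onnxruntime as ort
import torch

def run_onnx_and_compare(onnx_path, inp, te_output, atol=1e-5, rtol=1e-5):
    sess = ort.InferenceSession(
        onnx_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )
    feed = {sess.get_inputs()[0].name: inp.detach().cpu().numpy()}
    onnx_out = sess.run(None, feed)[0]
    # Count elements violating |onnx - te| <= atol + rtol * |te|.
    mismatches = ~np.isclose(onnx_out, te_output.detach().cpu().numpy(), atol=atol, rtol=rtol)
    assert mismatches.sum() == 0, f"{int(mismatches.sum())} mismatched elements"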
@@ -262,7 +279,7 @@ def validate_result( print("registered custom FP8 Q/DQ ops!") """Create an ONNX Runtime session for validation.""" - kwargs = {"providers": ['CUDAExecutionProvider', 'CPUExecutionProvider']} + kwargs = {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]} if is_fp8: sess_options = ort.SessionOptions() load_custom_ops(sess_options) @@ -288,10 +305,12 @@ def validate_result( ort_s = create_ort_session(fname, is_fp8) input_feed = create_ort_input_dict(ort_s, inps) onnx_outputs = ort_s.run(None, input_feed=input_feed) - compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname) + compare_outputs( + onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname + ) -def create_meta(scale_factor: float, size: int=1): +def create_meta(scale_factor: float, size: int = 1): meta = tex.FP8TensorMeta() meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor @@ -324,7 +343,9 @@ def get_attn_mask_str(use_mask, attn_mask_type): return "_mask" if use_mask else "_no-mask" attn_mask_str = "_arbitrary-no-mask" attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str - attn_mask_str = "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str + attn_mask_str = ( + "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str + ) return attn_mask_str @@ -351,17 +372,11 @@ class FP8GemmModule(nn.Module): self.outp_type = precision def forward(self, inp, weight): - inp_fp8 = cast_to_fp8( - inp, - self.meta_inp, - self.fp8_tensor_inp, - self.inp_type) + inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type) weight_fp8 = cast_to_fp8( - weight, - self.meta_weight, - self.fp8_tensor_weight, - self.weights_type) + weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type + ) ret, _ = fp8_gemm( weight_fp8, @@ -376,9 +391,11 @@ class FP8GemmModule(nn.Module): get_workspace(), bias=self.bias, use_bias=self.use_bias, - use_split_accumulator=False) + use_split_accumulator=False, + ) return ret + """ Tests cases begin here. """ @@ -387,13 +404,17 @@ Tests cases begin here. 
@skip_FP8 @pytest.mark.parametrize("scale_factor", [1, 224]) @pytest.mark.parametrize( - "precision, atol", [ - [torch.float32, 1e-7], - [torch.float16, 1e-7], - [torch.bfloat16, 5e-3], - ["fake-torch.bfloat16", 5e-3], -]) -def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype): + "precision, atol", + [ + [torch.float32, 1e-7], + [torch.float16, 1e-7], + [torch.bfloat16, 5e-3], + ["fake-torch.bfloat16", 5e-3], + ], +) +def test_export_cast_ops( + seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype +): fake_bf16_io = precision == "fake-torch.bfloat16" # reset precision to torch.bfloat16 after capturing fake BF16 mode precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision @@ -408,18 +429,9 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre self.fake_bf16_io = fake_bf16_io def forward(self, inp): - ret = cast_to_fp8( - inp, - self.meta, - self.fp8_tensor, - self.fp8_type) + ret = cast_to_fp8(inp, self.meta, self.fp8_tensor, self.fp8_type) - ret = cast_from_fp8( - ret, - self.meta, - self.fp8_tensor, - self.fp8_type, - self.highprec_type) + ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type) if self.fake_bf16_io: ret = ret.type(torch.float32) return ret @@ -427,8 +439,9 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre # Set dimensions (these are arbitrary). in_features = 64 hidden_size = 256 - inp = torch.randn(hidden_size, in_features, device="cuda", - dtype=torch.float if fake_bf16_io else precision) + inp = torch.randn( + hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision + ) high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx" model = TestFP8_QDQ(fake_bf16_io) @@ -439,15 +452,18 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre if fake_bf16_io or precision != torch.bfloat16: validate_result(fname, inp, model, atol=atol, is_fp8=True, te_outputs=te_outputs) + @skip_FP8 @pytest.mark.parametrize("scale_factor", [448]) @pytest.mark.parametrize( - "precision, atol", [ - [torch.float32, 1e-5], - [torch.float16, 1e-5], - [torch.bfloat16, 5e-3], - ["fake-torch.bfloat16", 5e-3] -]) + "precision, atol", + [ + [torch.float32, 1e-5], + [torch.float16, 1e-5], + [torch.bfloat16, 5e-3], + ["fake-torch.bfloat16", 5e-3], + ], +) def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float): fake_bf16_io = precision == "fake-torch.bfloat16" # reset precision to torch.bfloat16 after capturing fake BF16 mode @@ -463,17 +479,8 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa self.fake_bf16_io = fake_bf16_io def forward(self, inp): - ret = gelu( - inp, - self.meta, - self.fp8_tensor, - self.fp8_type) - ret = cast_from_fp8( - ret, - self.meta, - self.fp8_tensor, - self.fp8_type, - self.highprec_type) + ret = gelu(inp, self.meta, self.fp8_tensor, self.fp8_type) + ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type) if self.fake_bf16_io: ret = ret.type(torch.float32) return ret @@ -481,8 +488,9 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa # Set dimensions (these are arbitrary). 
in_features = 64 hidden_size = 256 - inp = torch.randn(hidden_size, in_features, device="cuda", - dtype=torch.float if fake_bf16_io else precision) + inp = torch.randn( + hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision + ) high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx" model = TestFP8_Gelu(fake_bf16_io) @@ -490,39 +498,55 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa te_outputs = te_infer(model, inp, is_fp8=True) serialize_inputs_outputs(fname, inp, te_outputs) if fake_bf16_io or precision != torch.bfloat16: - validate_result(fname, inp, model, rtol=0, atol=atol, is_fp8=True, allow_cnt_errors=2, te_outputs=te_outputs) + validate_result( + fname, + inp, + model, + rtol=0, + atol=atol, + is_fp8=True, + allow_cnt_errors=2, + te_outputs=te_outputs, + ) -@pytest.mark.parametrize("scale_factors", - [(224, 224,), -]) @pytest.mark.parametrize( - "precision, use_fp8, use_bias, use_gelu", [ - (torch.float32, False, False, False), - (torch.float16, False, False, False), - (torch.bfloat16, False, False, False), - (torch.float32, False, True, False), - (torch.float16, False, True, False), - (torch.bfloat16, False, True, False), - (torch.float32, False, True, True), - (torch.float16, False, True, True), - (torch.bfloat16, False, True, True), - - # For FP8 GEMM GeLU is not used. - (torch.float32, True, False, False), - (torch.float16, True, False, False), - (torch.bfloat16, True, False, False), - # When enabling bias we must use float16 or bfloat16 (because of kernel limitations) - (torch.float16, True, True, False), - (torch.bfloat16, True, True, False), -]) + "scale_factors", + [ + ( + 224, + 224, + ), + ], +) +@pytest.mark.parametrize( + "precision, use_fp8, use_bias, use_gelu", + [ + (torch.float32, False, False, False), + (torch.float16, False, False, False), + (torch.bfloat16, False, False, False), + (torch.float32, False, True, False), + (torch.float16, False, True, False), + (torch.bfloat16, False, True, False), + (torch.float32, False, True, True), + (torch.float16, False, True, True), + (torch.bfloat16, False, True, True), + # For FP8 GEMM GeLU is not used. + (torch.float32, True, False, False), + (torch.float16, True, False, False), + (torch.bfloat16, True, False, False), + # When enabling bias we must use float16 or bfloat16 (because of kernel limitations) + (torch.float16, True, True, False), + (torch.bfloat16, True, True, False), + ], +) def test_export_gemm( seed_default_rng, - precision, # Precision of inputs, weights, output and bias + precision, # Precision of inputs, weights, output and bias use_fp8, use_bias, use_gelu, - scale_factors + scale_factors, ): # Skip FP8 tests on non-hopper devices if use_fp8 and not fp8_available: @@ -548,21 +572,20 @@ def test_export_gemm( inp, outp_type, get_workspace(), - # test bias bias=self.bias, use_bias=self.use_bias, - # test gelu gelu=self.gelu, gelu_input=self.gelu_input, - grad=False, # only True for backward pass + grad=False, # only True for backward pass accumulate=False, ) return ret # If gelu is applied then bias must be added, as defined by TE kernel. - if use_gelu: assert use_bias + if use_gelu: + assert use_bias # Set dimensions (these are arbitrary). 
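# Editor's note (illustration, not part of the patch): the FP8 GEMM path exercised above
# quantizes inputs and weights with per-tensor scales and carries their reciprocals
# (scale_inv) into the GEMM. A small numeric sketch of the idea; the fake_quantize
# helper just scales and rounds and is not the real E4M3 format.
import torch

def fake_quantize(x, scale):
    q = torch.round(x * scale)   # coarse "quantized" representation
    return q, 1.0 / scale        # value plus its scale_inv

x = torch.randn(64, 256)
w = torch.randn(128, 256)
xq, x_scale_inv = fake_quantize(x, 224.0)
wq, w_scale_inv = fake_quantize(w, 224.0)
# The de-scaled product should approximate the full-precision GEMM.
approx = (xq * x_scale_inv) @ (wq * w_scale_inv).T
print(torch.allclose(approx, x @ w.T, atol=5e-1))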
out_features = 128 hidden_size = 256 @@ -574,45 +597,64 @@ def test_export_gemm( gelu_str = "_gelu" if use_gelu else "" high_prec_str = dtype2str(precision) fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx" - input_names = ['input', 'weight'] + input_names = ["input", "weight"] if use_fp8: - model = FP8GemmModule(precision, use_bias, use_gelu, scale_factors, hidden_size, out_features) + model = FP8GemmModule( + precision, use_bias, use_gelu, scale_factors, hidden_size, out_features + ) do_export(model, (inp, weight), fname, use_fp8, input_names=input_names) te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) if precision != torch.bfloat16: - validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2, - is_fp8=True, input_names=input_names, te_outputs=te_outputs) + validate_result( + fname, + (inp, weight), + model, + rtol=1e-2, + atol=2e-2, + is_fp8=True, + input_names=input_names, + te_outputs=te_outputs, + ) else: model = Test_GEMM(precision, use_bias, use_gelu) do_export(model, (inp, weight), fname, use_fp8, input_names=input_names) te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) if precision != torch.bfloat16: - validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2, - input_names=input_names, te_outputs=te_outputs) + validate_result( + fname, + (inp, weight), + model, + rtol=1e-2, + atol=2e-2, + input_names=input_names, + te_outputs=te_outputs, + ) @pytest.mark.parametrize("scale_factor", [448, 112]) @pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize( - "use_fp8, precision, atol", [ - [False, torch.float32, 1e-7], - [False, torch.float16, 1e-7], - [False, torch.bfloat16, 1e-7], - [False, "fake-torch.bfloat16", 1e-7], - [True, torch.float32, 1e-7], - [True, torch.float16, 1e-7], - [True, torch.bfloat16, 1e-2], - [True, "fake-torch.bfloat16", 1e-2] -]) + "use_fp8, precision, atol", + [ + [False, torch.float32, 1e-7], + [False, torch.float16, 1e-7], + [False, torch.bfloat16, 1e-7], + [False, "fake-torch.bfloat16", 1e-7], + [True, torch.float32, 1e-7], + [True, torch.float16, 1e-7], + [True, torch.bfloat16, 1e-2], + [True, "fake-torch.bfloat16", 1e-2], + ], +) def test_export_layernorm( seed_default_rng, use_fp8: bool, scale_factor: float, precision: torch.dtype, zero_centered_gamma: bool, - atol: float + atol: float, ): fake_bf16_io = precision == "fake-torch.bfloat16" # reset precision to torch.bfloat16 after capturing fake BF16 mode @@ -628,10 +670,15 @@ def test_export_layernorm( class Test_Layernorm(nn.Module): def __init__(self) -> None: super().__init__() - eps = 1e-6 # An arbitrary small value + eps = 1e-6 # An arbitrary small value dtype = torch.float if fake_bf16_io else precision - self.ln = te.LayerNorm(inp_shape[1], eps, params_dtype=dtype, - zero_centered_gamma=zero_centered_gamma).eval().cuda() + self.ln = ( + te.LayerNorm( + inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma + ) + .eval() + .cuda() + ) def forward(self, inp): ret = self.ln(inp) @@ -641,11 +688,13 @@ def test_export_layernorm( def __init__(self) -> None: super().__init__() normalized_shape = torch.Size(inp.shape[1:]) - self.weight = torch.randn(*normalized_shape, device="cuda", - dtype=torch.float32 if fake_bf16_io else precision) - self.bias = torch.zeros(*normalized_shape, device="cuda", - dtype=torch.float32 if 
fake_bf16_io else precision) - self.eps = 1e-6 # An arbitrary small value + self.weight = torch.randn( + *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision + ) + self.bias = torch.zeros( + *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision + ) + self.eps = 1e-6 # An arbitrary small value self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT self.meta = create_meta(scale_factor) @@ -661,14 +710,12 @@ def test_export_layernorm( self.fp8_tensor, self.fp8_type, 0, - zero_centered_gamma) + zero_centered_gamma, + ) ret = cast_from_fp8( - ret, - self.meta, - self.fp8_tensor, - self.fp8_type, - as_te_type(precision)) + ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision) + ) if fake_bf16_io: ret = ret.type(torch.float32) return ret @@ -683,28 +730,32 @@ def test_export_layernorm( serialize_inputs_outputs(fname, inp, te_outputs) if fake_bf16_io or precision != torch.bfloat16: validate_result( - fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs) + fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs + ) + @pytest.mark.parametrize("scale_factor", [448, 112]) @pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize( - "use_fp8, precision, atol", [ - [False, torch.float32, 1e-7], - [False, torch.float16, 1e-7], - [False, torch.bfloat16, 1e-7], - [False, "fake-torch.bfloat16", 1e-7], - [True, torch.float32, 1e-7], - [True, torch.float16, 1e-7], - [True, torch.bfloat16, 1e-2], - [True, "fake-torch.bfloat16", 1e-2] -]) + "use_fp8, precision, atol", + [ + [False, torch.float32, 1e-7], + [False, torch.float16, 1e-7], + [False, torch.bfloat16, 1e-7], + [False, "fake-torch.bfloat16", 1e-7], + [True, torch.float32, 1e-7], + [True, torch.float16, 1e-7], + [True, torch.bfloat16, 1e-2], + [True, "fake-torch.bfloat16", 1e-2], + ], +) def test_export_rmsnorm( seed_default_rng, use_fp8: bool, scale_factor: float, precision: torch.dtype, zero_centered_gamma: bool, - atol: float + atol: float, ): fake_bf16_io = precision == "fake-torch.bfloat16" # reset precision to torch.bfloat16 after capturing fake BF16 mode @@ -720,10 +771,15 @@ def test_export_rmsnorm( class Test_RMSnorm(nn.Module): def __init__(self) -> None: super().__init__() - eps = 1e-6 # An arbitrary small value + eps = 1e-6 # An arbitrary small value dtype = torch.float if fake_bf16_io else precision - self.ln = te.RMSNorm(inp_shape[1], eps, params_dtype=dtype, - zero_centered_gamma=zero_centered_gamma).eval().cuda() + self.ln = ( + te.RMSNorm( + inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma + ) + .eval() + .cuda() + ) def forward(self, inp): ret = self.ln(inp) @@ -733,9 +789,10 @@ def test_export_rmsnorm( def __init__(self) -> None: super().__init__() normalized_shape = torch.Size(inp.shape[1:]) - self.weight = torch.randn(*normalized_shape, device="cuda", - dtype=torch.float32 if fake_bf16_io else precision) - self.eps = 1e-6 # An arbitrary small value + self.weight = torch.randn( + *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision + ) + self.eps = 1e-6 # An arbitrary small value self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT self.meta = create_meta(scale_factor) @@ -750,14 +807,12 @@ def test_export_rmsnorm( self.fp8_tensor, self.fp8_type, 0, - zero_centered_gamma) + zero_centered_gamma, + ) ret = cast_from_fp8( - ret, - self.meta, - self.fp8_tensor, - self.fp8_type, - as_te_type(precision)) + ret, 
self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision) + ) if fake_bf16_io: ret = ret.type(torch.float32) return ret @@ -772,7 +827,8 @@ def test_export_rmsnorm( serialize_inputs_outputs(fname, inp, te_outputs) if fake_bf16_io or precision != torch.bfloat16: validate_result( - fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs) + fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs + ) @pytest.mark.parametrize("scale_factor", [1]) @@ -780,23 +836,25 @@ def test_export_rmsnorm( # Returning the bias is a TE fusion optimization we don't care about. @pytest.mark.parametrize("return_bias", [False]) @pytest.mark.parametrize( - "precision, use_bias",[ - (torch.float32, False), - (torch.float32, True), - (torch.float16, False), - (torch.float16, True), - # Todo: cannot configure BF16 when bias is disabled (ORT issue?) - (torch.bfloat16, False), - # Todo: cannot configure BF16 when bias is enabled (ORT issue?) - (torch.bfloat16, True), -]) + "precision, use_bias", + [ + (torch.float32, False), + (torch.float32, True), + (torch.float16, False), + (torch.float16, True), + # Todo: cannot configure BF16 when bias is disabled (ORT issue?) + (torch.bfloat16, False), + # Todo: cannot configure BF16 when bias is enabled (ORT issue?) + (torch.bfloat16, True), + ], +) def test_export_linear( seed_default_rng, scale_factor: float, use_fp8: bool, use_bias: bool, return_bias: bool, - precision: torch.dtype + precision: torch.dtype, ): # Skip FP8 tests on non-hopper devices if use_fp8 and not fp8_available: @@ -808,20 +866,14 @@ def test_export_linear( hidden_size = 256 class Test_Linear(nn.Module): - def __init__(self, - in_features, - out_features, - use_bias, - return_bias, - precision - ): + def __init__(self, in_features, out_features, use_bias, return_bias, precision): super().__init__() self.linear = te.Linear( in_features, out_features, bias=use_bias, return_bias=return_bias, - params_dtype=precision + params_dtype=precision, ) def forward(self, inp): @@ -834,20 +886,16 @@ def test_export_linear( high_prec_str = dtype2str(precision) fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx" with te.fp8_autocast(enabled=use_fp8): - model = Test_Linear( - in_features, - out_features, - use_bias, - return_bias, - precision - ).to(device='cuda') + model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to( + device="cuda" + ) if use_fp8: set_layer_scale(model.linear, scale_factor, num_gemms=1) do_export(model, inp, fname, use_fp8) te_outputs = te_infer(model, inp, is_fp8=use_fp8) serialize_inputs_outputs(fname, inp, te_outputs) - if precision in (torch.bfloat16, ): + if precision in (torch.bfloat16,): return if not use_fp8: validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) @@ -861,14 +909,16 @@ def test_export_linear( @pytest.mark.parametrize("return_bias", [False]) @pytest.mark.parametrize("return_layernorm_output", [False]) @pytest.mark.parametrize( - "precision, use_bias",[ - (torch.float32, False), - (torch.float32, True), - (torch.float16, True), - (torch.float16, False), - (torch.bfloat16, True), - (torch.bfloat16, False), -]) + "precision, use_bias", + [ + (torch.float32, False), + (torch.float32, True), + (torch.float16, True), + (torch.float16, False), + (torch.bfloat16, True), + (torch.bfloat16, False), + ], +) @pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize("normalization", all_normalizations) def test_export_layernorm_linear( @@ 
-907,13 +957,13 @@ def test_export_layernorm_linear( params_dtype=precision, zero_centered_gamma=zero_centered_gamma, normalization=normalization, - ).to(device='cuda') + ).to(device="cuda") if use_fp8: set_layer_scale(model, scale_factor, num_gemms=1) do_export(model, inp, fname, use_fp8) te_outputs = te_infer(model, inp, is_fp8=use_fp8) serialize_inputs_outputs(fname, inp, te_outputs) - if precision in (torch.bfloat16, ): + if precision in (torch.bfloat16,): return if not use_fp8: validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) @@ -927,14 +977,16 @@ def test_export_layernorm_linear( @pytest.mark.parametrize("return_bias", [False]) @pytest.mark.parametrize("return_layernorm_output", [False]) @pytest.mark.parametrize( - "precision, use_bias",[ - (torch.float32, False), - (torch.float32, True), - (torch.float16, True), - (torch.float16, False), - (torch.bfloat16, True), - (torch.bfloat16, False), -]) + "precision, use_bias", + [ + (torch.float32, False), + (torch.float32, True), + (torch.float16, True), + (torch.float16, False), + (torch.bfloat16, True), + (torch.bfloat16, False), + ], +) @pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize("activation", supported_activations) @pytest.mark.parametrize("normalization", all_normalizations) @@ -954,7 +1006,6 @@ def test_export_layernorm_mlp( if use_fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) - # Set dimensions (these are arbitrary). in_features = 64 out_features = 256 @@ -977,30 +1028,32 @@ def test_export_layernorm_mlp( zero_centered_gamma=zero_centered_gamma, activation=activation, normalization=normalization, - ).to(device='cuda') + ).to(device="cuda") if use_fp8: set_layer_scale(model, scale_factor, num_gemms=2) do_export(model, inp, fname, use_fp8) te_outputs = te_infer(model, inp, is_fp8=use_fp8) serialize_inputs_outputs(fname, inp, te_outputs) - if precision in (torch.bfloat16, ): + if precision in (torch.bfloat16,): return - atol = 1e-6 if use_fp8 else (5e-1 if activation=="swiglu" else 1e-3) + atol = 1e-6 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3) validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs) @skip_FP8 @pytest.mark.parametrize( - "precision, use_mask, attn_mask_type", [ - (torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) - (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask) - (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) - (torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) - (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) - (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) - (torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) - (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) -]) + "precision, use_mask, attn_mask_type", + [ + (torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) + (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask) + (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) + (torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) + (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) + (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) 
+ (torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) + (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) + ], +) def test_export_core_attention( seed_default_rng, set_max_seq_len, @@ -1034,40 +1087,42 @@ def test_export_core_attention( attention_dropout=0.5, qkv_format=qkv_format, attn_mask_type=attn_mask_type, - ).to(device='cuda') - do_export(model, - inp, - fname, - input_names=input_names, - use_fp8=True) + ).to(device="cuda") + do_export(model, inp, fname, input_names=input_names, use_fp8=True) te_outputs = te_infer(model, inp, is_fp8=True) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if precision in (torch.bfloat16, ): + if precision in (torch.bfloat16,): return - validate_result(fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs) + validate_result( + fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs + ) test_configs_multihead_attention = [ - #"use_mask, attn_mask_type" - (False, "no_mask"), # calls ScaledSoftmax - (True, "arbitrary"), # calls ScaledMaskedSoftmax + # "use_mask, attn_mask_type" + (False, "no_mask"), # calls ScaledSoftmax + (True, "arbitrary"), # calls ScaledMaskedSoftmax ] test_configs_attention_type = [ - #"input_layernorm, attention_type, fuse_qkv_params" - (True, "self", True), - (False, "self", True), - (True, "self", False), - (False, "self", False), - (True, "cross", True), - (False, "cross", True), - (True, "cross", False), - (False, "cross", False), + # "input_layernorm, attention_type, fuse_qkv_params" + (True, "self", True), + (False, "self", True), + (True, "self", False), + (False, "self", False), + (True, "cross", True), + (False, "cross", True), + (True, "cross", False), + (False, "cross", False), ] + + @pytest.mark.parametrize("use_fp8", [False, True]) @pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("return_layernorm_output", [False]) -@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type) +@pytest.mark.parametrize( + "input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type +) def test_export_multihead_attention( seed_default_rng, set_max_seq_len, @@ -1078,7 +1133,7 @@ def test_export_multihead_attention( return_layernorm_output: bool, input_layernorm: bool, attention_type: str, - fuse_qkv_params: bool + fuse_qkv_params: bool, ): # Skip FP8 tests on non-hopper devices if use_fp8 and not fp8_available: @@ -1102,17 +1157,23 @@ def test_export_multihead_attention( output_layer_init_method, ) - hidden_states_context = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") + hidden_states_context = torch.randn( + sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" + ) attention_mask = None if use_mask and attn_mask_type != "causal": # Generate a random mask with 50% probability for 0 or 1. 
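# Editor's note (illustration, not part of the patch): a standalone version of the random
# boolean mask built below with torch.bernoulli, plus the conventional way such a mask is
# applied to attention scores (flagged positions pushed to -inf before softmax). Whether
# True means "masked out" or "may attend" is a convention of the consuming module, so
# treat this purely as a sketch.
import torch

batch, seq = 4, 32
probs = 0.5 * torch.ones(batch, 1, seq, seq)
mask = torch.bernoulli(probs).bool()                 # ~50% of positions flagged
scores = torch.randn(batch, 1, seq, seq)
weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)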
- probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision) + probs = 0.5 * torch.ones( + batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision + ) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) encoder_output = None if attention_type == "cross": - encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") + encoder_output = torch.randn( + sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" + ) fp8_str = "_fp8" if use_fp8 else "" dtype_str = dtype2str(precision) @@ -1131,49 +1192,98 @@ def test_export_multihead_attention( attention_type=attention_type, fuse_qkv_params=fuse_qkv_params, return_bias=True, - ).to(device='cuda') + ).to(device="cuda") inp_context = (hidden_states_context, attention_mask, encoder_output) input_names = ["hidden_states", "attention_mask", "encoder_output"] - output_names=["attention_output", "attention_bias"] - do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names, - dynamic_axes={"hidden_states": {0: "seq", 1:"bs"}, - "attention_output": {0: "seq", 1:"bs"}}) + output_names = ["attention_output", "attention_bias"] + do_export( + model, + inp_context, + fname, + use_fp8, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "hidden_states": {0: "seq", 1: "bs"}, + "attention_output": {0: "seq", 1: "bs"}, + }, + ) te_outputs = te_infer(model, inp_context, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp_context, te_outputs, input_names=input_names, output_names=output_names) - if precision in (torch.bfloat16, ): + serialize_inputs_outputs( + fname, inp_context, te_outputs, input_names=input_names, output_names=output_names + ) + if precision in (torch.bfloat16,): return if not use_fp8: - validate_result(fname, inp_context, model, atol=1e-3, input_names=input_names, - output_names=output_names, te_outputs=te_outputs) + validate_result( + fname, + inp_context, + model, + atol=1e-3, + input_names=input_names, + output_names=output_names, + te_outputs=te_outputs, + ) else: - validate_result(fname, inp_context, model, atol=1e-2, is_fp8=use_fp8, - input_names=input_names, output_names=output_names, allow_cnt_errors=3, - te_outputs=te_outputs) + validate_result( + fname, + inp_context, + model, + atol=1e-2, + is_fp8=use_fp8, + input_names=input_names, + output_names=output_names, + allow_cnt_errors=3, + te_outputs=te_outputs, + ) # In GPT generative phase (inference) the input sequence is smaller than the maximum # allowed sequence length and we want to test this condition. # Pretend that we're in generative phase when it makes sense (causal mask and self-attention). 
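# Editor's note (illustration, not part of the patch): the "generative phase" check below
# re-runs the exported graph with a shorter sequence, which only works because the export
# above declared dynamic "seq"/"bs" axes. A toy end-to-end illustration; "toy.onnx" and
# the Linear model are placeholders.
import torch
import onnxruntime as ort

model = torch.nn.Linear(16, 16).eval()
full = torch.randn(32, 2, 16)   # (seq, batch, hidden) used at export time
torch.onnx.export(
    model, full, "toy.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "seq", 1: "bs"}, "output": {0: "seq", 1: "bs"}},
)
sess = ort.InferenceSession("toy.onnx", providers=["CPUExecutionProvider"])
short = torch.randn(8, 2, 16)   # shorter "generative" sequence reuses the same graph
out = sess.run(None, {"input": short.numpy()})[0]
assert out.shape == (8, 2, 16)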
- is_generative_phase = (attn_mask_type == "causal" and attention_type == "self") + is_generative_phase = attn_mask_type == "causal" and attention_type == "self" if is_generative_phase: seq_len_offset = 8 - hidden_states_generative = torch.randn(sequence_length-seq_len_offset, batch_size, hidden_size, dtype=precision, device="cuda") + hidden_states_generative = torch.randn( + sequence_length - seq_len_offset, + batch_size, + hidden_size, + dtype=precision, + device="cuda", + ) inp_generative = (hidden_states_generative, attention_mask, encoder_output) if not use_fp8: - validate_result(fname, inp_generative, model, atol=1e-3, input_names=input_names, output_names=output_names) + validate_result( + fname, + inp_generative, + model, + atol=1e-3, + input_names=input_names, + output_names=output_names, + ) else: - validate_result(fname, inp_generative, model, atol=1e-2, is_fp8=use_fp8, - input_names=input_names, output_names=output_names, allow_cnt_errors=3) - + validate_result( + fname, + inp_generative, + model, + atol=1e-2, + is_fp8=use_fp8, + input_names=input_names, + output_names=output_names, + allow_cnt_errors=3, + ) @pytest.mark.parametrize("use_fp8", [False, True]) @pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) -@pytest.mark.parametrize("output_layernorm", [ - #True, # TO DO: handle this - False -]) +@pytest.mark.parametrize( + "output_layernorm", + [ + # True, # TO DO: handle this + False + ], +) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("fuse_qkv_params", [False, True]) @pytest.mark.parametrize("zero_centered_gamma", [False, True]) @@ -1201,12 +1311,16 @@ def test_export_transformer_layer( ffn_hidden_size = 256 num_attention_heads = 4 - input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") + input_tensor = torch.rand( + sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" + ) input_names = ["input", "attention_mask"] attention_mask = None if use_mask and attn_mask_type != "causal": # Generate a random mask with 50% probability for 0 or 1. 
- probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision) + probs = 0.5 * torch.ones( + batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision + ) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) inp = (input_tensor, attention_mask) @@ -1225,19 +1339,30 @@ def test_export_transformer_layer( params_dtype=precision, fuse_qkv_params=fuse_qkv_params, zero_centered_gamma=zero_centered_gamma, - activation=activation).to(device='cuda') + activation=activation, + ).to(device="cuda") do_export(model, inp, fname, use_fp8, input_names=input_names) te_outputs = te_infer(model, inp, is_fp8=use_fp8) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if precision in (torch.bfloat16, ): + if precision in (torch.bfloat16,): return - atol = 5e-1 if use_fp8 else (5e-1 if activation=="swiglu" else 1e-3) - validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs) + atol = 5e-1 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3) + validate_result( + fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs + ) @pytest.mark.parametrize("use_fp8", [True]) -@pytest.mark.parametrize("ln_scale_factor", [448*2]) -@pytest.mark.parametrize("gemm_scale_factors", [(224, 224,),]) +@pytest.mark.parametrize("ln_scale_factor", [448 * 2]) +@pytest.mark.parametrize( + "gemm_scale_factors", + [ + ( + 224, + 224, + ), + ], +) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("zero_centered_gamma", [False, True]) def test_export_gemm_layernorm( @@ -1246,7 +1371,7 @@ def test_export_gemm_layernorm( ln_scale_factor: float, gemm_scale_factors: Tuple[float, float], precision: torch.dtype, - zero_centered_gamma: bool + zero_centered_gamma: bool, ): """This is a regression test for testing that all LN inputs have the same type. 
@@ -1260,20 +1385,26 @@ def test_export_gemm_layernorm( # Skip FP8 tests on non-hopper devices if use_fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + class TestFP8_GemmLayernorm(nn.Module): def __init__(self) -> None: super().__init__() normalized_shape = torch.Size(inp.shape[1:]) self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda") self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda") - self.eps = 1e-6 # An arbitrary small value + self.eps = 1e-6 # An arbitrary small value self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT self.meta = create_meta(ln_scale_factor) self.fp8_type = tex.DType.kFloat8E4M3 self.gemm = FP8GemmModule( - precision, use_bias=False, gelu=False, scale_factors=gemm_scale_factors, - hidden_size=hidden_size, out_features=out_features) + precision, + use_bias=False, + gelu=False, + scale_factors=gemm_scale_factors, + hidden_size=hidden_size, + out_features=out_features, + ) def forward(self, inp, weight): x = self.gemm(inp, weight) @@ -1286,14 +1417,16 @@ def test_export_gemm_layernorm( self.fp8_tensor, self.fp8_type, 0, - zero_centered_gamma) + zero_centered_gamma, + ) x = cast_from_fp8( x, self.meta, self.fp8_tensor, self.fp8_type, - tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16) + tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16, + ) return x inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda") @@ -1302,14 +1435,21 @@ def test_export_gemm_layernorm( high_prec_str = dtype2str(precision) fp8_str = f"_fp8" if use_fp8 else "" fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx" - input_names = ['input', 'weight'] + input_names = ["input", "weight"] do_export(model, (inp, weight), fname, use_fp8=use_fp8, input_names=input_names) te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) - if precision not in (torch.bfloat16, ): + if precision not in (torch.bfloat16,): validate_result( - fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2, - input_names=input_names, te_outputs=te_outputs) + fname, + (inp, weight), + model, + atol=5e-2, + is_fp8=use_fp8, + allow_cnt_errors=2, + input_names=input_names, + te_outputs=te_outputs, + ) @skip_FP8 @@ -1357,32 +1497,61 @@ def test_export_gpt_generation( output_layernorm=output_layernorm, params_dtype=precision, fuse_qkv_params=fuse_qkv_params, - zero_centered_gamma=zero_centered_gamma).to(device='cuda') + zero_centered_gamma=zero_centered_gamma, + ).to(device="cuda") # "Context phase": use full input sequence length input_names = ["input"] output_names = ["output"] - input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") + input_tensor = torch.rand( + sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" + ) inp = (input_tensor,) - do_export(model, inp, fname, use_fp8, - input_names=input_names, output_names=output_names, - dynamic_axes={"input": {0: "seq", 1:"bs"}, - "output": {0: "seq", 1:"bs"}, }) + do_export( + model, + inp, + fname, + use_fp8, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "input": {0: "seq", 1: "bs"}, + "output": {0: "seq", 1: "bs"}, + }, + ) te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names, output_names=output_names) - if precision not in (torch.bfloat16, ): - validate_result(fname, inp, model, atol=6e-3, 
is_fp8=use_fp8, input_names=input_names, - te_outputs=te_outputs) + serialize_inputs_outputs( + fname, inp, te_outputs, input_names=input_names, output_names=output_names + ) + if precision not in (torch.bfloat16,): + validate_result( + fname, + inp, + model, + atol=6e-3, + is_fp8=use_fp8, + input_names=input_names, + te_outputs=te_outputs, + ) # "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8. sequence_length = 1 if not use_fp8 else 8 - input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") + input_tensor = torch.rand( + sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" + ) inp = (input_tensor, attention_mask) te_outputs = te_infer(model, inp, is_fp8=use_fp8) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if precision not in (torch.bfloat16, ): - validate_result(fname, inp, model, atol=6e-3, is_fp8=use_fp8, input_names=input_names, - te_outputs=te_outputs) + if precision not in (torch.bfloat16,): + validate_result( + fname, + inp, + model, + atol=6e-3, + is_fp8=use_fp8, + input_names=input_names, + te_outputs=te_outputs, + ) @pytest.mark.parametrize("enabled", [True, False]) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 5c16da982b87ac073b94a5ed59d70f7b7f9f5eda..8c854b65fbf3a708d1b18c2b47bed188259dd623 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -19,6 +19,7 @@ from transformer_engine.pytorch.fp8 import ( # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) class TestFP8Recipe: @@ -95,8 +96,8 @@ class TestFP8Recipe: ref_amax_backward = amax_history_backward[-1] else: raise ValueError(f"{amax_compute_algo=} is not supported") - ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin) - ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin) + ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin) + ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin) ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) update_weight_amax = is_first_microbatch is None or is_first_microbatch if not update_weight_amax: @@ -128,8 +129,8 @@ class TestFP8Recipe: ref_amax_backward = amax_history_backward[-1] else: raise ValueError(f"{amax_compute_algo=} is not supported") - ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin) - ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin) + ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin) + ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin) ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) @@ -180,8 +181,9 @@ class TestFP8Recipe: scaling_factor_compute_algo = None if fused_update: scaling_factor_compute_algo = ( - lambda amax, scale, fp8_max, recipe: - te.fp8._default_sf_compute(amax, scale, fp8_max, recipe.margin) + lambda amax, scale, fp8_max, recipe: te.fp8._default_sf_compute( + amax, scale, fp8_max, recipe.margin + ) ) recipe = transformer_engine.common.recipe.DelayedScaling( fp8_format=fp8_format, scaling_factor_compute_algo=scaling_factor_compute_algo @@ -205,7 +207,9 @@ class TestFP8Recipe: # test different scenarios if 
amax_case == "zero": - fp8_meta[forward_key].amax_history = torch.tensor([[0]], dtype=torch.float32, device="cuda") + fp8_meta[forward_key].amax_history = torch.tensor( + [[0]], dtype=torch.float32, device="cuda" + ) expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda") elif amax_case == "tiny": # calculate the minimum amax value that results in a FP32 maximum scale @@ -254,4 +258,6 @@ class TestFP8Recipe: ) torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale) - torch.testing.assert_close(fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale)) + torch.testing.assert_close( + fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale) + ) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 8f5913b70e4010f300498ba47b20daf2b85464db..4c9dea803f9590827049a2c28876ac237c6916dd 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -30,7 +30,13 @@ from transformer_engine.pytorch import ( ) from transformer_engine.common import recipe import transformer_engine_torch as tex -from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8 +from transformer_engine.pytorch.cpp_extensions import ( + gemm, + fp8_gemm, + gelu, + cast_to_fp8, + cast_from_fp8, +) from transformer_engine.pytorch.module.base import get_workspace from test_onnx_export import create_meta @@ -75,6 +81,7 @@ class ModelConfig: return False return True + model_configs = { "126m": ModelConfig(12, 2048, 2, 768, 12), "small": ModelConfig(2, 32, 2, 64, 2), @@ -82,7 +89,7 @@ model_configs = { } fp8_recipes = [ - None, # Handles non-FP8 case + None, # Handles non-FP8 case recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3), recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID), recipe.DelayedScaling( @@ -126,6 +133,7 @@ batch_sizes_with_zero = [0, 1, 2] all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"] all_normalizations = ["LayerNorm", "RMSNorm"] + def _disable_wgrads(block): for p in block.parameters(): p.requires_grad = False @@ -143,8 +151,17 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): optimizer = torch.optim.SGD(block.parameters(), lr=0.1) # Placeholders used for capture. 
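# Editor's note (illustration, not part of the patch): the static placeholders created
# below follow the standard PyTorch CUDA Graphs capture pattern. A minimal, self-contained
# sketch of that pattern, assuming a CUDA device; the toy Linear model and shapes are
# illustrative only.
import torch

model = torch.nn.Linear(64, 64).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
static_input = torch.randn(8, 64, device="cuda")
static_target = torch.randn(8, 64, device="cuda")

# Warm up on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        loss = torch.nn.functional.mse_loss(model(static_input), static_target)
        loss.backward()
        opt.step()
torch.cuda.current_stream().wait_stream(s)

# Capture one fwd/bwd/step iteration into a graph.
g = torch.cuda.CUDAGraph()
opt.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_loss = torch.nn.functional.mse_loss(model(static_input), static_target)
    static_loss.backward()
    opt.step()

# Replay with fresh data: copy into the captured placeholders, then replay the graph.
static_input.copy_(torch.rand_like(static_input))
static_target.copy_(torch.rand_like(static_target))
g.replay()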
- static_input = torch.randn(config.seq_len, config.batch_size, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True) - static_target = torch.randn(config.seq_len, config.batch_size, config.hidden_size, device='cuda', dtype=dtype) + static_input = torch.randn( + config.seq_len, + config.batch_size, + config.hidden_size, + device="cuda", + dtype=dtype, + requires_grad=True, + ) + static_target = torch.randn( + config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype + ) real_input = torch.rand_like(static_input) real_target = torch.rand_like(static_target) @@ -403,11 +420,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz config = model_configs[model] module = RMSNorm if normalization == "RMSNorm" else LayerNorm - block = ( - module(config.hidden_size) - .to(dtype=torch.float32) - .cuda() - ) + block = module(config.hidden_size).to(dtype=torch.float32).cuda() _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) @@ -418,9 +431,9 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("skip_dgrad", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) -def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad, - zero_centered_gamma, skip_dgrad, - normalization): +def test_sanity_layernorm_linear( + dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization +): config = model_configs[model] if fp8_recipe is not None: @@ -480,7 +493,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias): config = model_configs[model] ffn_hidden_size = 4 * config.hidden_size - num_tokens = bs*config.seq_len + num_tokens = bs * config.seq_len if fp8_recipe is not None: if not fp8_available: @@ -490,15 +503,9 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ use_fp8 = fp8_recipe is not None with fp8_model_init(enabled=use_fp8 and fp8_model_params): - te_linear = ( - Linear( - config.hidden_size, - ffn_hidden_size, - bias=use_bias, - params_dtype=dtype - ) - .cuda() - ) + te_linear = Linear( + config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype + ).cuda() inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True @@ -518,9 +525,9 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ @pytest.mark.parametrize("skip_dgrad", all_boolean) @pytest.mark.parametrize("activation", all_activations) @pytest.mark.parametrize("normalization", all_normalizations) -def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad, - zero_centered_gamma, skip_dgrad, activation, - normalization): +def test_sanity_layernorm_mlp( + dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization +): config = model_configs[model] if fp8_recipe is not None: @@ -557,10 +564,18 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad, @pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) @pytest.mark.parametrize("cpu_offload", all_boolean) -def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad, - zero_centered_gamma, bias, activation, - normalization, parallel_attention_mlp, - cpu_offload): +def test_sanity_gpt( + 
dtype, + fp8_recipe, + model, + skip_wgrad, + zero_centered_gamma, + bias, + activation, + normalization, + parallel_attention_mlp, + cpu_offload, +): config = model_configs[model] if fp8_recipe is not None: @@ -625,8 +640,7 @@ def test_sanity_gpt_126m(): @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) -def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, - normalization): +def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization): config = model_configs[model] if fp8_recipe is not None: @@ -683,8 +697,7 @@ def test_sanity_bert_126m(): @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) -def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, - normalization): +def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization): config = model_configs[model] if fp8_recipe is not None: @@ -845,7 +858,9 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) -def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma): +def test_sanity_gradient_accumulation_fusion( + dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma +): config = model_configs[model] if fp8_recipe is not None: @@ -885,8 +900,7 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) -def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, - normalization): +def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization): config = model_configs[model] if fp8_recipe is not None: @@ -919,9 +933,10 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad) + def test_model_multiple_cast(): - a = torch.zeros((16,16), device="cuda") - m = Linear(16,32) + a = torch.zeros((16, 16), device="cuda") + m = Linear(16, 32) y = m(a) assert y.dtype == torch.float32 @@ -937,15 +952,11 @@ def test_model_multiple_cast(): @pytest.mark.parametrize("offset", [1, 3, 5]) @pytest.mark.parametrize("datatype", param_types) def test_sanity_gemm_with_unalignment(N, offset, datatype): - scratchpad = torch.randn(N*N + 2*offset, device="cuda", dtype=datatype) + scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype) inp = torch.reshape(scratchpad[offset:-offset], (N, N)) - weight = torch.reshape(scratchpad[offset*2:], (N, N)) + weight = torch.reshape(scratchpad[offset * 2 :], (N, N)) - _, _, _ = gemm( - A=weight, - B=inp, - dtype=datatype, - workspace=get_workspace()) + _, _, _ = gemm(A=weight, B=inp, dtype=datatype, workspace=get_workspace()) torch.cuda.synchronize() @@ -954,38 +965,35 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype): @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) def test_sanity_fp8_gemm_with_unalignment(N, datatype): offset = 16 - scratchpad = 
torch.randn(N*N + offset, device="cuda", dtype=datatype) + scratchpad = torch.randn(N * N + offset, device="cuda", dtype=datatype) fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT nb_inp_scales, nb_weight_scales = 1, N - scale_factor = 1. + scale_factor = 1.0 meta_inp = create_meta(scale_factor, nb_inp_scales) meta_weight = create_meta(scale_factor, nb_weight_scales) inp_type = tex.DType.kFloat8E4M3 weights_type = tex.DType.kFloat8E4M3 outp_type = datatype - scratchpad_fp8 = cast_to_fp8( - scratchpad, - meta_weight, - fp8_tensor_inp, - inp_type) + scratchpad_fp8 = cast_to_fp8(scratchpad, meta_weight, fp8_tensor_inp, inp_type) inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N)) weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N)) _, _ = fp8_gemm( - weight_fp8, - meta_weight.scale_inv, - fp8_tensor_weight, - inp_type, - inp_fp8, - meta_inp.scale_inv, - fp8_tensor_inp, - weights_type, - outp_type, - get_workspace(), - bias=None, - use_bias=False, - use_split_accumulator=False) + weight_fp8, + meta_weight.scale_inv, + fp8_tensor_weight, + inp_type, + inp_fp8, + meta_inp.scale_inv, + fp8_tensor_inp, + weights_type, + outp_type, + get_workspace(), + bias=None, + use_bias=False, + use_split_accumulator=False, + ) torch.cuda.synchronize() diff --git a/tests/pytorch/test_sanity_import.py b/tests/pytorch/test_sanity_import.py index 4ddd6eb6e2f59755fe80dd85e69943f1bdf9fab4..954d807b7dac6aef81407201536bf59b34e71dca 100644 --- a/tests/pytorch/test_sanity_import.py +++ b/tests/pytorch/test_sanity_import.py @@ -3,4 +3,5 @@ # See LICENSE for license information. import transformer_engine.pytorch + print("OK") diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index 74f76f7eb82e6218bb02f90cc8a3a6656f974c9a..6af7ede234095b5f382bba81d23cda2e893efabb 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -28,7 +28,7 @@ from transformer_engine.pytorch.module.base import TransformerEngineBaseModule fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -def init_meta(size: int=1): +def init_meta(size: int = 1): meta = tex.FP8TensorMeta() meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") @@ -65,22 +65,18 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd self.inp_type = tex.DType.kFloat8E4M3 self.weights_type = tex.DType.kFloat8E4M3 self.outp_type = precision - + def get_fp8_weights_scratchpad(self, is_first_microbatch): - raise RuntimeError("Method get_fp8_weights_scratchpad is dummy and should not be invoked.") + raise RuntimeError( + "Method get_fp8_weights_scratchpad is dummy and should not be invoked." 
+ ) def forward(self, inp, weight): - inp_fp8 = cast_to_fp8( - inp, - self.meta_inp, - self.fp8_tensor_inp, - self.inp_type) + inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type) weight_fp8 = cast_to_fp8( - weight, - self.meta_weight, - self.fp8_tensor_weight, - self.weights_type) + weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type + ) ret = fp8_gemm( weight_fp8, @@ -95,20 +91,33 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd get_workspace(), bias=self.bias, use_bias=self.use_bias, - use_split_accumulator=False) + use_split_accumulator=False, + ) return ret model_in = Test_TE_Export(precision, True) with te.fp8_autocast(enabled=True): model_in.init_fp8_metadata() # scaling fwd - model_in.fp8_meta["scaling_fwd"].scale = torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd - model_in.fp8_meta["scaling_fwd"].scale_inv = torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd - model_in.fp8_meta["scaling_fwd"].amax_history = torch.ones(3, dtype=torch.float32, device="cuda") * history_fwd + model_in.fp8_meta["scaling_fwd"].scale = ( + torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd + ) + model_in.fp8_meta["scaling_fwd"].scale_inv = ( + torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd + ) + model_in.fp8_meta["scaling_fwd"].amax_history = ( + torch.ones(3, dtype=torch.float32, device="cuda") * history_fwd + ) # scaling bwd - model_in.fp8_meta["scaling_bwd"].scale = torch.ones(2, dtype=torch.float32, device="cuda") * scale_bwd - model_in.fp8_meta["scaling_bwd"].scale_inv = torch.ones(2, dtype=torch.float32, device="cuda") / scale_bwd - model_in.fp8_meta["scaling_bwd"].amax_history = torch.ones(2, dtype=torch.float32, device="cuda") * history_bwd + model_in.fp8_meta["scaling_bwd"].scale = ( + torch.ones(2, dtype=torch.float32, device="cuda") * scale_bwd + ) + model_in.fp8_meta["scaling_bwd"].scale_inv = ( + torch.ones(2, dtype=torch.float32, device="cuda") / scale_bwd + ) + model_in.fp8_meta["scaling_bwd"].amax_history = ( + torch.ones(2, dtype=torch.float32, device="cuda") * history_bwd + ) torch.save(model_in.state_dict(), tmp_filename) @@ -117,13 +126,27 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd model_out.eval() # scaling fwd - assert torch.allclose(model_in.fp8_meta["scaling_fwd"].scale, model_out.fp8_meta["scaling_fwd"].scale) - assert torch.allclose(model_in.fp8_meta["scaling_fwd"].scale_inv, model_out.fp8_meta["scaling_fwd"].scale_inv) - assert torch.allclose(model_in.fp8_meta["scaling_fwd"].amax_history, model_out.fp8_meta["scaling_fwd"].amax_history) + assert torch.allclose( + model_in.fp8_meta["scaling_fwd"].scale, model_out.fp8_meta["scaling_fwd"].scale + ) + assert torch.allclose( + model_in.fp8_meta["scaling_fwd"].scale_inv, model_out.fp8_meta["scaling_fwd"].scale_inv + ) + assert torch.allclose( + model_in.fp8_meta["scaling_fwd"].amax_history, + model_out.fp8_meta["scaling_fwd"].amax_history, + ) # scaling bwd - assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale, model_out.fp8_meta["scaling_bwd"].scale) - assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv) - assert torch.allclose(model_in.fp8_meta["scaling_bwd"].amax_history, model_out.fp8_meta["scaling_bwd"].amax_history) + assert torch.allclose( + model_in.fp8_meta["scaling_bwd"].scale, model_out.fp8_meta["scaling_bwd"].scale + ) + assert torch.allclose( + model_in.fp8_meta["scaling_bwd"].scale_inv, 
model_out.fp8_meta["scaling_bwd"].scale_inv + ) + assert torch.allclose( + model_in.fp8_meta["scaling_bwd"].amax_history, + model_out.fp8_meta["scaling_bwd"].amax_history, + ) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @@ -132,7 +155,7 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd def test_fp8_model_checkpoint( save_fp8_model: bool, load_fp8_model: bool, - dims: Iterable[int] = [32,32], + dims: Iterable[int] = [32, 32], dtype: torch.dtype = torch.float32, device: Union[torch.device, str] = "cuda", ): @@ -153,7 +176,7 @@ def test_fp8_model_checkpoint( with te.fp8_autocast(): y_ref = model(x.detach().clone()).detach().clone() - fp8_meta_ref = { "scaling_fwd": {}, "scaling_bwd": {} } + fp8_meta_ref = {"scaling_fwd": {}, "scaling_bwd": {}} with te.fp8_autocast(), torch.no_grad(): fp8_meta_fwd = model.fp8_meta["scaling_fwd"] fp8_meta_bwd = model.fp8_meta["scaling_bwd"] @@ -168,7 +191,7 @@ def test_fp8_model_checkpoint( fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"]) fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"]) del fp8_meta_fwd, fp8_meta_bwd - + # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] # This line copies the fp8 scale_inv from the model metadata to the weight fp8 tensor. # The sole purpose of the following lines is to set the scale_inv of the weight tensor, which is the simplest method. @@ -226,15 +249,14 @@ def test_fp8_model_checkpoint( with pytest.raises(AssertionError): torch.testing.assert_close(y, y_ref, **tols) - # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] - # When save_fp8_model=True, we load a model with weights in high precision, + # When save_fp8_model=True, we load a model with weights in high precision, # which does not include _scale_inv, - # but has the fp8 scaling factor in the meta data. This scenario can occur + # but has the fp8 scaling factor in the meta data. This scenario can occur # when using te.fp8_autocast(enabled=False, calibrating=True). # # In such cases, the default behavior of load_state_dict is incorrect - it loads tensors first, - # followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior + # followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior # is corrected by overriding the _load_state_dict method from PyTorch in TransformerEngineBaseModule, # to load the fp8 metadata before loading tensors. # @@ -262,4 +284,6 @@ def test_fp8_model_checkpoint( # We need to ensure that the tensor's scale_inv parameter matches its meta data. # This is crucial to avoid confusion about which value is correct. meta_index = model.weight._fp8_meta_index - torch.testing.assert_close(model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item()) \ No newline at end of file + torch.testing.assert_close( + model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item() + ) diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 12e1b37e8fc304c9fd971ad75e4f348a57ba8ce9..fc93705dff4866cb47c8f35ed006be6100563f24 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -4,74 +4,59 @@ * See LICENSE for license information. 
************************************************************************/ -#include #include -#include "../util/vectorized_pointwise.h" -#include "../common.h" +#include +#include "../common.h" +#include "../util/vectorized_pointwise.h" namespace transformer_engine { -template -void act_fn(const Tensor &input, - Tensor *output, - cudaStream_t stream) { +template +void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { CheckInputTensor(input, "act_lu_input"); CheckOutputTensor(*output, "act_lu_output"); NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); const size_t tot_elts = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - tot_elts, - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), tot_elts, {}, + stream);); // NOLINT(*) + ); // NOLINT(*) } -template -void dact_fn(const Tensor &grad, - const Tensor &input, - Tensor *output, - cudaStream_t stream) { +template +void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { CheckInputTensor(input, "dact_lu_input"); CheckInputTensor(grad, "dact_lu_input_grad"); CheckOutputTensor(*output, "dact_lu_output"); NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - NVTE_CHECK(input.data.dtype == grad.data.dtype, - "Input and incoming gradient types must match."); + NVTE_CHECK(input.data.dtype == grad.data.dtype, "Input and incoming gradient types must match."); const size_t tot_elts = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - tot_elts, - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), tot_elts, {}, + stream);); // NOLINT(*) + ); // NOLINT(*) } -template -void gated_act_fn(const Tensor &input, - Tensor *output, - cudaStream_t stream) { +template +void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { CheckInputTensor(input, "gated_act_input"); CheckOutputTensor(*output, "gated_act_output"); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); @@ -81,29 +66,23 @@ void gated_act_fn(const Tensor &input, 
NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, "Input shape[1] must be 2x larger than output shape[1]."); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - output->data.shape[0], - output->data.shape[1], - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); + GatedActivationKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), output->data.shape[0], + output->data.shape[1], {}, + stream);); // NOLINT(*) + ); // NOLINT(*) } -template -void dgated_act_fn(const Tensor &grad, - const Tensor &input, - Tensor *output, - cudaStream_t stream) { +template +void dgated_act_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { CheckInputTensor(grad, "dgated_act_grad"); CheckInputTensor(input, "dgated_act_input"); CheckOutputTensor(*output, "dgated_act_output"); @@ -114,23 +93,19 @@ void dgated_act_fn(const Tensor &grad, "Output shape[0] must be equal to grad shape[0]."); NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2, "Output shape[1] must be 2x larger than grad shape[1]."); - NVTE_CHECK(input.data.shape == output->data.shape, - "Input and output shapes must match."); + NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - grad.data.shape[0], - grad.data.shape[1], - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); + DGatedActivationKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), grad.data.shape[0], grad.data.shape[1], + {}, + stream);); // NOLINT(*) + ); // NOLINT(*) } } // namespace transformer_engine - diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 9ed0a23cec4e7c68de3fbb79f9da055495cdf3b0..f9cd7b845a83ae90ad84950f03e6f8493824e8ab 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -3,96 +3,69 @@ * * See LICENSE for license information. 
************************************************************************/ -#include "./activation_template.h" #include "../util/math.h" +#include "./activation_template.h" - -void nvte_gelu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_gelu); using namespace transformer_engine; act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + reinterpret_cast(output), stream); } -void nvte_dgelu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgelu); using namespace transformer_engine; dact_fn>(*reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), - stream); + reinterpret_cast(output), stream); } -void nvte_geglu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + reinterpret_cast(output), stream); } -void nvte_dgeglu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgeglu); using namespace transformer_engine; dgated_act_fn, dgelu>( - *reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + *reinterpret_cast(grad), *reinterpret_cast(input), + reinterpret_cast(output), stream); } -void nvte_qgelu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgelu); using namespace transformer_engine; act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + reinterpret_cast(output), stream); } -void nvte_dqgelu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream) { NVTE_API_CALL(nvte_dqgelu); using namespace transformer_engine; dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + *reinterpret_cast(input), + reinterpret_cast(output), stream); } -void nvte_qgeglu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + reinterpret_cast(output), stream); } -void nvte_dqgeglu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu); using namespace transformer_engine; dgated_act_fn, dqgelu>( - *reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + *reinterpret_cast(grad), *reinterpret_cast(input), + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index 
5d6b02d5edc46b30553a10104cb60bd5c76499be..c18d018a8e0843bcb48c1a83d452c8209c73313e 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -4,96 +4,69 @@ * See LICENSE for license information. ************************************************************************/ -#include "./activation_template.h" #include "../util/math.h" +#include "./activation_template.h" - -void nvte_relu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_relu); using namespace transformer_engine; act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + reinterpret_cast(output), stream); } -void nvte_drelu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_drelu); using namespace transformer_engine; dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + *reinterpret_cast(input), + reinterpret_cast(output), stream); } -void nvte_reglu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + reinterpret_cast(output), stream); } -void nvte_dreglu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dreglu); using namespace transformer_engine; dgated_act_fn, drelu>( - *reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + *reinterpret_cast(grad), *reinterpret_cast(input), + reinterpret_cast(output), stream); } -void nvte_srelu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_srelu); using namespace transformer_engine; act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + reinterpret_cast(output), stream); } -void nvte_dsrelu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream) { NVTE_API_CALL(nvte_dsrelu); using namespace transformer_engine; dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + *reinterpret_cast(input), + reinterpret_cast(output), stream); } -void nvte_sreglu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + reinterpret_cast(output), stream); } -void nvte_dsreglu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu); using namespace transformer_engine; dgated_act_fn, dsrelu>( - *reinterpret_cast(grad), - *reinterpret_cast(input), - 
reinterpret_cast(output), - stream); + *reinterpret_cast(grad), *reinterpret_cast(input), + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 736fb9e4373c9ae9a2a20b70a0f055f823374b6b..c745ffeeb4c5b004260f3e29f4dec87ae69ef742 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -4,51 +4,37 @@ * See LICENSE for license information. ************************************************************************/ -#include "./activation_template.h" #include "../util/math.h" +#include "./activation_template.h" - -void nvte_silu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_silu); using namespace transformer_engine; act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + reinterpret_cast(output), stream); } -void nvte_dsilu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsilu); using namespace transformer_engine; dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + *reinterpret_cast(input), + reinterpret_cast(output), stream); } -void nvte_swiglu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + reinterpret_cast(output), stream); } -void nvte_dswiglu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu); using namespace transformer_engine; dgated_act_fn, dsilu>( - *reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + *reinterpret_cast(grad), *reinterpret_cast(input), + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 92193c04ed80f723b11e3eba4e9a05d6114cae6d..42b529f3886aadc13dde8c615155a7a83e2e5505 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -7,6 +7,12 @@ #ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_ #define TRANSFORMER_ENGINE_COMMON_COMMON_H_ +#include +#include +#include +#include +#include + #include #include #include @@ -15,12 +21,6 @@ #include #include -#include -#include -#include -#include - -#include #include "./nvtx.h" #include "./util/logging.h" @@ -31,8 +31,8 @@ struct SimpleTensor { std::vector shape; DType dtype; - SimpleTensor(void *dptr, const std::vector &shape, DType dtype) : - dptr(dptr), shape(shape), dtype(dtype) {} + SimpleTensor(void *dptr, const std::vector &shape, DType dtype) + : dptr(dptr), shape(shape), dtype(dtype) {} SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} }; @@ -42,15 +42,16 @@ struct Tensor { SimpleTensor scale; SimpleTensor scale_inv; - Tensor() : data(), - amax(nullptr, {1}, DType::kFloat32), - scale(nullptr, {1}, DType::kFloat32), - scale_inv(nullptr, {1}, DType::kFloat32) {} + Tensor() + : data(), + amax(nullptr, {1}, DType::kFloat32), + scale(nullptr, {1}, DType::kFloat32), + 
scale_inv(nullptr, {1}, DType::kFloat32) {} }; template constexpr T DIVUP(const T &x, const T &y) { - return (((x) + ((y)-1)) / (y)); + return (((x) + ((y)-1)) / (y)); } using byte = uint8_t; @@ -65,8 +66,11 @@ namespace detail { template constexpr inline const char *type_name() noexcept; -#define TRANSFORMER_ENGINE_TYPE_NAME(T) \ - template <> inline constexpr const char *type_name() noexcept { return #T; } +#define TRANSFORMER_ENGINE_TYPE_NAME(T) \ + template <> \ + inline constexpr const char *type_name() noexcept { \ + return #T; \ + } TRANSFORMER_ENGINE_TYPE_NAME(uint8_t) TRANSFORMER_ENGINE_TYPE_NAME(int32_t) TRANSFORMER_ENGINE_TYPE_NAME(float) @@ -79,214 +83,167 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2) } // namespace detail template -struct TypeInfo{ - using types = std::tuple; - - template - struct Helper { - constexpr static DType getType() { - constexpr int i = static_cast(current); - if (std::is_same::type>::value) { - return current; - } else { - return Helper(i + 1)>::getType(); - } - } - }; - - template - struct Helper { - constexpr static DType getType() { - return DType::kNumTypes; - } - }; - - template +struct TypeInfo { + using types = std::tuple; + + template + struct Helper { constexpr static DType getType() { - return Helper::getType(); + constexpr int i = static_cast(current); + if (std::is_same::type>::value) { + return current; + } else { + return Helper(i + 1)>::getType(); + } } + }; + + template + struct Helper { + constexpr static DType getType() { return DType::kNumTypes; } + }; - constexpr static DType dtype = getType(); - constexpr static size_t size = sizeof(T); - constexpr static const char *name = detail::type_name(); + template + constexpr static DType getType() { + return Helper::getType(); + } + + constexpr static DType dtype = getType(); + constexpr static size_t size = sizeof(T); + constexpr static const char *name = detail::type_name(); }; #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kByte: \ - { \ - using type = unsigned char; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kInt32: \ - { \ - using type = float; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat32: \ - { \ - using type = float; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat16: \ - { \ - using type = fp16; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kBFloat16: \ - { \ - using type = bf16; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat8E4M3: \ - { \ - using type = fp8e4m3; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat8E5M2: \ - { \ - using type = fp8e5m2; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - NVTE_ERROR("Invalid type."); \ - } + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kByte: { \ + using type = unsigned char; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } #define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) 
\ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: \ - { \ - using type = float; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat16: \ - { \ - using type = fp16; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kBFloat16: \ - { \ - using type = bf16; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat8E5M2: \ - { \ - using type = fp8e5m2; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat8E4M3: \ - { \ - using type = fp8e4m3; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - NVTE_ERROR("Invalid type."); \ - } + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat8E5M2: \ - { \ - using type = fp8e5m2; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat8E4M3: \ - { \ - using type = fp8e4m3; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - NVTE_ERROR("Invalid type."); \ - } + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } #define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: \ - { \ - using type = float; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat16: \ - { \ - using type = fp16; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kBFloat16: \ - { \ - using type = bf16; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat8E5M2: \ - case DType::kFloat8E4M3: \ - { \ - NVTE_ERROR("FP8 type not instantiated for input."); \ - } \ - break; \ - default: \ - NVTE_ERROR("Invalid type."); \ - } - -#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \ - switch (dtype) \ - { \ - using namespace transformer_engine; \ - case DType::kFloat16: \ - { \ - using type = fp16; \ - __VA_ARGS__; \ - break; \ - } \ - case DType::kBFloat16: \ - { \ - using type = bf16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - NVTE_ERROR("Invalid type for 16 bit."); \ - } + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: \ + case DType::kFloat8E4M3: { \ + NVTE_ERROR("FP8 type not instantiated for input."); \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) 
\ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat16: { \ + using type = fp16; \ + __VA_ARGS__; \ + break; \ + } \ + case DType::kBFloat16: { \ + using type = bf16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + NVTE_ERROR("Invalid type for 16 bit."); \ + } //////////////////////////////////////////////////////////////////////////////////////////////////// inline size_t product(const std::vector &shape) { - size_t ret = 1; - for (const auto &elem : shape) { - ret *= elem; - } - return ret; + size_t ret = 1; + for (const auto &elem : shape) { + ret *= elem; + } + return ret; } inline int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; } template @@ -306,7 +263,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt bool is_fp8_dtype(const DType t); #define NVTE_API_CALL(api_name) \ - transformer_engine::nvtx::NVTXWrapper _ ## api_name ## _nvtx_wrapper(#api_name); + transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name); } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index a2eab9a70864195bbc2d1ec9f080048eae641233..e5a38793ba260941a0b39c33340696d712eed7ac 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -5,79 +5,74 @@ ************************************************************************/ #include "transformer_engine/fused_attn.h" + #include "../common.h" -#include "utils.h" -#include "fused_attn_f16_max512_seqlen.h" -#include "fused_attn_f16_arbitrary_seqlen.h" -#include "fused_attn_fp8.h" #include "../util/cuda_runtime.h" #include "../util/system.h" +#include "fused_attn_f16_arbitrary_seqlen.h" +#include "fused_attn_f16_max512_seqlen.h" +#include "fused_attn_fp8.h" +#include "utils.h" // map NVTE_QKV_Layout to NVTE_QKV_Layout_Group NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { - switch (qkv_layout) { - case NVTE_QKV_Layout::NVTE_SB3HD: - case NVTE_QKV_Layout::NVTE_BS3HD: - case NVTE_QKV_Layout::NVTE_T3HD: - return NVTE_QKV_Layout_Group::NVTE_3HD; - case NVTE_QKV_Layout::NVTE_SBH3D: - case NVTE_QKV_Layout::NVTE_BSH3D: - case NVTE_QKV_Layout::NVTE_TH3D: - return NVTE_QKV_Layout_Group::NVTE_H3D; - case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: - case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: - case NVTE_QKV_Layout::NVTE_THD_T2HD: - return NVTE_QKV_Layout_Group::NVTE_HD_2HD; - case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: - case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: - case NVTE_QKV_Layout::NVTE_THD_TH2D: - return NVTE_QKV_Layout_Group::NVTE_HD_H2D; - case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: - case NVTE_QKV_Layout::NVTE_THD_THD_THD: - return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD; - default: - NVTE_ERROR("qkv_layout not supported!"); - } + switch (qkv_layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + return NVTE_QKV_Layout_Group::NVTE_3HD; + case NVTE_QKV_Layout::NVTE_SBH3D: + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + return NVTE_QKV_Layout_Group::NVTE_H3D; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + return NVTE_QKV_Layout_Group::NVTE_HD_2HD; + 
case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + return NVTE_QKV_Layout_Group::NVTE_HD_H2D; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD; + default: + NVTE_ERROR("qkv_layout not supported!"); + } } // map NVTE_QKV_Layout to NVTE_QKV_Format NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { - switch (qkv_layout) { - case NVTE_QKV_Layout::NVTE_SB3HD: - case NVTE_QKV_Layout::NVTE_SBH3D: - case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: - case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: - case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: - return NVTE_QKV_Format::NVTE_SBHD; - case NVTE_QKV_Layout::NVTE_BS3HD: - case NVTE_QKV_Layout::NVTE_BSH3D: - case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: - case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: - return NVTE_QKV_Format::NVTE_BSHD; - case NVTE_QKV_Layout::NVTE_T3HD: - case NVTE_QKV_Layout::NVTE_TH3D: - case NVTE_QKV_Layout::NVTE_THD_T2HD: - case NVTE_QKV_Layout::NVTE_THD_TH2D: - case NVTE_QKV_Layout::NVTE_THD_THD_THD: - return NVTE_QKV_Format::NVTE_THD; - default: - NVTE_ERROR("qkv_layout not supported!"); - } + switch (qkv_layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + case NVTE_QKV_Layout::NVTE_SBH3D: + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + return NVTE_QKV_Format::NVTE_SBHD; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + return NVTE_QKV_Format::NVTE_BSHD; + case NVTE_QKV_Layout::NVTE_T3HD: + case NVTE_QKV_Layout::NVTE_TH3D: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + return NVTE_QKV_Format::NVTE_THD; + default: + NVTE_ERROR("qkv_layout not supported!"); + } } // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - NVTEDType q_dtype, - NVTEDType kv_dtype, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - float dropout, - size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim) { + NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -85,96 +80,82 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); auto cudnn_runtime_version = cudnnGetVersion(); - if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) - || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) - && (sm_arch_ >= 90) - && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) - && ( - ((cudnn_runtime_version >= 8900) - && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) - && (max_seqlen_q == max_seqlen_kv) - && (max_seqlen_q <= 512) - && (head_dim == 64) - && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) - || ((cudnn_runtime_version >= 90100) - && 
(max_seqlen_q % 128 == 0) - && (max_seqlen_kv % 128 == 0) - && (head_dim == 128) - && ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) - || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) - && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) - || (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) { + if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) && + (sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && + (((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) && + (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim == 64) && + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) || + ((cudnn_runtime_version >= 90100) && (max_seqlen_q % 128 == 0) && + (max_seqlen_kv % 128 == 0) && (head_dim == 128) && + ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || + (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) && + ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." << std::endl; + " Please upgrade your cuDNN version if possible." + << std::endl; } } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { bool flag_m512 = false; bool flag_arb = false; - if ((sm_arch_ == 80 || sm_arch_ == 90) - && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) - && (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) - && (head_dim == 64) - && (num_attn_heads == num_gqa_groups) - && ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) - || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) - && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) - || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) - || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK - && max_seqlen_q == max_seqlen_kv) - || (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) - && ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD) - || (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) - || (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) - || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) { + if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && + (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim == 64) && + (num_attn_heads == num_gqa_groups) && + ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || + (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && + ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + max_seqlen_q == max_seqlen_kv) || + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) && + ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD) || + (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) || + (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) || + (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || + (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) { flag_m512 = true; } - if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) - || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) - && ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) - || (cudnn_runtime_version >= 90000)) - && ((cudnn_runtime_version < 8907 
&& num_attn_heads == num_gqa_groups) - || (cudnn_runtime_version >= 8907)) - && ((head_dim <= 128 && head_dim % 8 == 0) - // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // d=256 only supported for forward - || (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 - && head_dim <= 256 && head_dim % 8 == 0)) - && ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) - || ((cudnn_runtime_version >= 8906) - && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS - || (bias_type == NVTE_Bias_Type::NVTE_ALIBI - && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK - && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK - && sm_arch_ >= 90) - || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS - && sm_arch_ >= 90))) - || ((cudnn_runtime_version >= 90000) - && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS - && sm_arch_ >= 80))) - && ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) - || ((cudnn_runtime_version >= 8906) - && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK - || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK - || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK - || attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))) - && (!(cudnn_runtime_version >= 8906 - && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK - || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) - && ((qkv_format == NVTE_QKV_Format::NVTE_SBHD) - || (sm_arch_ >= 90 && cudnn_runtime_version >= 90100 - && num_attn_heads == num_gqa_groups - && qkv_format == NVTE_QKV_Format::NVTE_THD) - || (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) { + if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || + (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && + ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) || + (cudnn_runtime_version >= 90000)) && + ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || + (cudnn_runtime_version >= 8907)) && + ((head_dim <= 128 && head_dim % 8 == 0) + // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // d=256 only supported for forward + || (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 && + head_dim % 8 == 0)) && + ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || + ((cudnn_runtime_version >= 8906) && + (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || + (bias_type == NVTE_Bias_Type::NVTE_ALIBI && + attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && sm_arch_ >= 90) || + (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || + ((cudnn_runtime_version >= 90000) && + (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && + ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || + ((cudnn_runtime_version >= 8906) && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))) && + (!(cudnn_runtime_version >= 8906 && + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && + bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && + ((qkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == 
num_gqa_groups && + qkv_format == NVTE_QKV_Format::NVTE_THD) || + (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) { flag_arb = true; } - if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) - && (flag_arb == true)) { + if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { @@ -185,24 +166,26 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } int env_backend = static_cast(backend); env_backend = transformer_engine::getenv("NVTE_FUSED_ATTN_BACKEND", env_backend); - if (((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) - && flag_m512) - || ((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) - && flag_arb)) { - backend = static_cast(env_backend); + if (((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) && + flag_m512) || + ((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) && + flag_arb)) { + backend = static_cast(env_backend); } } - if (cudnn_runtime_version < 8901 - && backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (cudnn_runtime_version < 8901 && + backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+." - " Please upgrade your cuDNN version if possible." << std::endl; + " Please upgrade your cuDNN version if possible." + << std::endl; } - if (cudnn_runtime_version < 8900 - && backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + if (cudnn_runtime_version < 8900 && + backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." << std::endl; + " Please upgrade your cuDNN version if possible." 
+ << std::endl; } } else { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; @@ -211,49 +194,40 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with packed QKV -void nvte_fused_attn_fwd_qkvpacked( - const NVTETensor QKV, - const NVTETensor Bias, - NVTETensor S, - NVTETensor O, - NVTETensorPack* Aux_CTX_Tensors, - const NVTETensor cu_seqlens, - const NVTETensor seq_offsets_q, - const NVTETensor seq_offsets_k, - const NVTETensor seq_offsets_v, - const NVTETensor seq_offsets_o, - const NVTETensor rng_state, - size_t max_seqlen, - bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - NVTETensor workspace, - cudaStream_t stream) { +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, float attn_scale, + float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; - const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); - const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); - const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); - const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); - const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); - const Tensor *input_rng_state = reinterpret_cast(rng_state); - const Tensor *input_QKV = reinterpret_cast(QKV); - const Tensor *input_Bias = reinterpret_cast(Bias); - Tensor *input_output_S = reinterpret_cast(S); - Tensor *output_O = reinterpret_cast(O); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); + const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); + const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); + const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); + const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); + const Tensor *input_rng_state = reinterpret_cast(rng_state); + const Tensor *input_QKV = reinterpret_cast(QKV); + const Tensor *input_Bias = reinterpret_cast(Bias); + Tensor *input_output_S = reinterpret_cast(S); + Tensor *output_O = reinterpret_cast(O); + Tensor *wkspace = reinterpret_cast(workspace); auto ndim = input_QKV->data.shape.size(); size_t b = input_cu_seqlens->data.shape[0] - 1; size_t h = 0; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - h = input_QKV->data.shape[ndim - 2]; + h = input_QKV->data.shape[ndim - 2]; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - h = input_QKV->data.shape[ndim - 3]; + h = input_QKV->data.shape[ndim - 3]; } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); + NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); } size_t d = input_QKV->data.shape[ndim - 1]; @@ -261,49 +235,35 @@ void nvte_fused_attn_fwd_qkvpacked( const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = - 
nvte_get_fused_attn_backend( - QKV_type, QKV_type, - qkv_layout, bias_type, attn_mask_type, - dropout, h, h, max_seqlen, max_seqlen, d); + nvte_get_fused_attn_backend(QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, + dropout, h, h, max_seqlen, max_seqlen, d); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_qkvpacked( - b, h, max_seqlen, d, - is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_QKV, input_Bias, output_O, - Aux_CTX_Tensors, - input_cu_seqlens, - input_rng_state, - wkspace, stream, handle); + fused_attn_max_512_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_QKV, input_Bias, + output_O, Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, - is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_QKV, input_Bias, output_O, - Aux_CTX_Tensors, - input_cu_seqlens, - input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, - input_rng_state, - wkspace, stream, handle); + fused_attn_arbitrary_seqlen_fwd_qkvpacked( + b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd_qkvpacked( - b, h, max_seqlen, d, - is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_QKV, input_output_S, output_O, - Aux_CTX_Tensors, - input_cu_seqlens, - input_rng_state, - wkspace, stream, handle); + fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_QKV, input_output_S, output_O, + Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, wkspace, + stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif @@ -313,52 +273,39 @@ void nvte_fused_attn_fwd_qkvpacked( } // NVTE fused attention BWD with packed QKV void nvte_fused_attn_bwd_qkvpacked( - const NVTETensor QKV, - const NVTETensor O, - const NVTETensor dO, - const NVTETensor S, - NVTETensor dP, - const NVTETensorPack* Aux_CTX_Tensors, - NVTETensor dQKV, - NVTETensor dBias, - const NVTETensor cu_seqlens, - const NVTETensor seq_offsets_q, - const NVTETensor seq_offsets_k, - const NVTETensor seq_offsets_v, - const NVTETensor seq_offsets_o, - size_t max_seqlen, - float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - NVTETensor workspace, - cudaStream_t stream) { + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, size_t max_seqlen, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; - const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); - const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); - const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); - const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); - const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); - const Tensor *input_QKV = reinterpret_cast(QKV); - const Tensor *input_O = reinterpret_cast(O); - const Tensor *input_dO = reinterpret_cast(dO); - const Tensor *input_S = reinterpret_cast(S); - Tensor *input_output_dP = reinterpret_cast(dP); - Tensor *output_dQKV = reinterpret_cast(dQKV); - Tensor *output_dBias = reinterpret_cast(dBias); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); + const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); + const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); + const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); + const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); + const Tensor *input_QKV = reinterpret_cast(QKV); + const Tensor *input_O = reinterpret_cast(O); + const Tensor *input_dO = reinterpret_cast(dO); + const Tensor *input_S = reinterpret_cast(S); + Tensor *input_output_dP = reinterpret_cast(dP); + Tensor *output_dQKV = reinterpret_cast(dQKV); + Tensor *output_dBias = reinterpret_cast(dBias); + Tensor *wkspace = reinterpret_cast(workspace); auto ndim = input_QKV->data.shape.size(); size_t b = input_cu_seqlens->data.shape[0] - 1; size_t h = 0; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - h = input_QKV->data.shape[ndim - 2]; + h = input_QKV->data.shape[ndim - 2]; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - h = input_QKV->data.shape[ndim - 3]; + h = input_QKV->data.shape[ndim - 3]; } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); + NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); } size_t d = input_QKV->data.shape[ndim - 1]; @@ -366,66 +313,48 @@ void nvte_fused_attn_bwd_qkvpacked( const NVTEDType QKV_type = 
static_cast<NVTEDType>(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend( - QKV_type, QKV_type, - qkv_layout, bias_type, attn_mask_type, - dropout, h, h, max_seqlen, max_seqlen, d); + nvte_get_fused_attn_backend(QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, + dropout, h, h, max_seqlen, max_seqlen, d); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd_qkvpacked( - b, h, max_seqlen, d, - attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_QKV, input_dO, - output_S, - output_dQKV, output_dBias, - input_cu_seqlens, - wkspace, stream, handle); + Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); + fused_attn_max_512_bwd_qkvpacked( + b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, + input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); - input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); - } - fused_attn_arbitrary_seqlen_bwd_qkvpacked( - b, h, max_seqlen, d, - attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_QKV, input_O, input_dO, input_Bias, - output_S, - output_dQKV, output_dBias, - input_cu_seqlens, - input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, - input_rng_state, - wkspace, stream, handle); + Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); + Tensor *input_Bias, *input_rng_state; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); + input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]); + } else { + input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); + } + fused_attn_arbitrary_seqlen_bwd_qkvpacked( + b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, + input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, + input_rng_state, wkspace, stream, handle); #else const char *err_msg = - "cuDNN 8.9.0 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. \n"; + "cuDNN 8.9.0 is required for BF16/FP16 fused attention " + "with arbitrary sequence length.
\n"; NVTE_ERROR(err_msg); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd_qkvpacked( - b, h, max_seqlen, d, - attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_QKV, input_O, input_dO, - input_M, input_ZInv, - input_S, input_output_dP, - output_dQKV, - input_cu_seqlens, - input_rng_state, - wkspace, stream, handle); + const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv, + input_S, input_output_dP, output_dQKV, input_cu_seqlens, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -434,41 +363,31 @@ void nvte_fused_attn_bwd_qkvpacked( } } // NVTE fused attention FWD with packed KV -void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, - const NVTETensor KV, - const NVTETensor Bias, - NVTETensor S, - NVTETensor O, - NVTETensorPack* Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, - const NVTETensor seq_offsets_q, - const NVTETensor seq_offsets_k, - const NVTETensor seq_offsets_v, - const NVTETensor seq_offsets_o, - const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, - bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - NVTETensor workspace, - cudaStream_t stream) { +void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, + const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, float attn_scale, + float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); - const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); - const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); - const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); - const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); - const Tensor *input_rng_state = reinterpret_cast(rng_state); - const Tensor *input_Q = reinterpret_cast(Q); - const Tensor *input_KV = reinterpret_cast(KV); - const Tensor *input_Bias = reinterpret_cast(Bias); - Tensor *input_output_S = reinterpret_cast(S); - Tensor *output_O = reinterpret_cast(O); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor 
*input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv); + const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q); + const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k); + const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v); + const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o); + const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state); + const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q); + const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV); + const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias); + Tensor *input_output_S = reinterpret_cast<Tensor *>(S); + Tensor *output_O = reinterpret_cast<Tensor *>(O); + Tensor *wkspace = reinterpret_cast<Tensor *>(workspace); size_t b = input_cu_seqlens_q->data.shape[0] - 1; auto ndim = input_Q->data.shape.size(); @@ -478,11 +397,11 @@ void nvte_fused_attn_fwd_kvpacked( size_t h_kv = 0; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - h_kv = input_KV->data.shape[ndim_kv - 2]; + h_kv = input_KV->data.shape[ndim_kv - 2]; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - h_kv = input_KV->data.shape[ndim_kv - 3]; + h_kv = input_KV->data.shape[ndim_kv - 3]; } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); + NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); @@ -490,49 +409,35 @@ void nvte_fused_attn_fwd_kvpacked( const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend( - Q_type, KV_type, - qkv_layout, bias_type, attn_mask_type, - dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d); + nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_kvpacked( - b, h_q, max_seqlen_q, max_seqlen_kv, d, - is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_Q, input_KV, input_Bias, output_O, - Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, - wkspace, stream, handle); + fused_attn_max_512_fwd_kvpacked( + b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512.
\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - fused_attn_arbitrary_seqlen_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_Q, input_KV, input_Bias, output_O, - Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, - input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, - input_rng_state, - wkspace, stream, handle); + fused_attn_arbitrary_seqlen_fwd_kvpacked( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q, input_seq_offsets_k, + input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( - "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_Q, input_KV, input_output_S, output_O, - Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, - wkspace, stream, handle); + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_Q, input_KV, input_output_S, output_O, Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif @@ -542,46 +447,31 @@ void nvte_fused_attn_fwd_kvpacked( } // NVTE fused attention BWD with packed KV void nvte_fused_attn_bwd_kvpacked( - const NVTETensor Q, - const NVTETensor KV, - const NVTETensor O, - const NVTETensor dO, - const NVTETensor S, - NVTETensor dP, - const NVTETensorPack* Aux_CTX_Tensors, - NVTETensor dQ, - NVTETensor dKV, - NVTETensor dBias, - const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, - const NVTETensor seq_offsets_q, - const NVTETensor seq_offsets_k, - const NVTETensor seq_offsets_v, - const NVTETensor seq_offsets_o, - size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - NVTETensor workspace, - cudaStream_t stream) { + const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, + const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, + NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, + float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); - const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); - const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); - const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); - const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); - const Tensor *input_Q = reinterpret_cast(Q); - const Tensor *input_KV = reinterpret_cast(KV); - const Tensor *input_O = reinterpret_cast(O); - const Tensor *input_dO = reinterpret_cast(dO); - const Tensor *input_S = reinterpret_cast(S); - Tensor *input_output_dP = reinterpret_cast(dP); - Tensor *output_dQ = reinterpret_cast(dQ); - Tensor *output_dKV = reinterpret_cast(dKV); - Tensor *output_dBias = reinterpret_cast(dBias); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); + const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); + const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); + const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_KV = reinterpret_cast(KV); + const Tensor *input_O = reinterpret_cast(O); + const Tensor *input_dO = reinterpret_cast(dO); + const Tensor *input_S = reinterpret_cast(S); + Tensor *input_output_dP = reinterpret_cast(dP); + Tensor *output_dQ = reinterpret_cast(dQ); + Tensor *output_dKV = reinterpret_cast(dKV); + Tensor *output_dBias = reinterpret_cast(dBias); + Tensor *wkspace = reinterpret_cast(workspace); size_t b = input_cu_seqlens_q->data.shape[0] - 1; auto ndim = input_Q->data.shape.size(); @@ -591,11 +481,11 @@ void nvte_fused_attn_bwd_kvpacked( size_t h_kv = 0; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); if (layout_group 
== NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - h_kv = input_KV->data.shape[ndim_kv - 2]; + h_kv = input_KV->data.shape[ndim_kv - 2]; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - h_kv = input_KV->data.shape[ndim_kv - 3]; + h_kv = input_KV->data.shape[ndim_kv - 3]; } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); + NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); @@ -603,65 +493,51 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend( - Q_type, KV_type, - qkv_layout, bias_type, attn_mask_type, - dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d); + nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd_kvpacked( - b, h_q, max_seqlen_q, max_seqlen_kv, d, - attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_Q, input_KV, input_dO, - output_S, - output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, - wkspace, stream, handle); + Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); + fused_attn_max_512_bwd_kvpacked( + b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512.
\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - } - fused_attn_arbitrary_seqlen_bwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_Q, input_KV, input_O, input_dO, input_Bias, - output_S, - output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, - input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, - input_rng_state, wkspace, stream, handle); + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *input_Bias, *input_rng_state; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + } else { + input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + } + fused_attn_arbitrary_seqlen_bwd_kvpacked( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, input_Q, input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, + output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q, + input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, + stream, handle); #else const char *err_msg = - "cuDNN 8.9.3 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. \n"; + "cuDNN 8.9.3 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. \n"; NVTE_ERROR(err_msg); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_Q, input_KV, input_O, input_dO, - input_M, input_ZInv, - input_S, input_output_dP, - output_dQ, output_dKV, - input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, - wkspace, stream, handle); + const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, + output_dKV, input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif @@ -670,43 +546,32 @@ void nvte_fused_attn_bwd_kvpacked( } } // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd( - const NVTETensor Q, - const NVTETensor K, - const NVTETensor V, - const NVTETensor Bias, - NVTETensor S, - NVTETensor O, - NVTETensorPack* Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, - const NVTETensor seq_offsets_q, - const NVTETensor seq_offsets_k, - const NVTETensor seq_offsets_v, - const NVTETensor seq_offsets_o, - const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, - bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - NVTETensor workspace, - cudaStream_t stream) { +void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor Bias, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); - const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); - const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); - const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); - const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); - const Tensor *input_rng_state = reinterpret_cast(rng_state); - const Tensor *input_Q = reinterpret_cast(Q); - const Tensor *input_K = reinterpret_cast(K); - const Tensor *input_V = reinterpret_cast(V); - const Tensor *input_Bias = reinterpret_cast(Bias); - Tensor *input_output_S = reinterpret_cast(S); - Tensor *output_O = reinterpret_cast(O); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); + const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); + const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); + const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); + const Tensor *input_rng_state = reinterpret_cast(rng_state); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_K = reinterpret_cast(K); + const Tensor *input_V = reinterpret_cast(V); + const Tensor *input_Bias = reinterpret_cast(Bias); + Tensor *input_output_S = reinterpret_cast(S); + Tensor *output_O = reinterpret_cast(O); + Tensor *wkspace = reinterpret_cast(workspace); auto ndim = input_Q->data.shape.size(); size_t b = input_cu_seqlens_q->data.shape[0] - 1; @@ -719,49 +584,35 @@ void nvte_fused_attn_fwd( const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend( - Q_type, KV_type, - qkv_layout, bias_type, attn_mask_type, - dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d); + 
nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd( - b, h_q, max_seqlen_q, max_seqlen_kv, d, - is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_Q, input_K, input_V, input_Bias, output_O, - Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, - wkspace, stream, handle); + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_Q, input_K, input_V, input_Bias, output_O, - Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, - input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, - input_rng_state, - wkspace, stream, handle); + fused_attn_arbitrary_seqlen_fwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q, input_seq_offsets_k, + input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_Q, input_K, input_V, input_output_S, output_O, - Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, - wkspace, stream, handle); + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif @@ -770,51 +621,36 @@ void nvte_fused_attn_fwd( } } // NVTE fused attention BWD with separate Q, K and V -void nvte_fused_attn_bwd( - const NVTETensor Q, - const NVTETensor K, - const NVTETensor V, - const NVTETensor O, - const NVTETensor dO, - const NVTETensor S, - NVTETensor dP, - const NVTETensorPack* Aux_CTX_Tensors, - NVTETensor dQ, - NVTETensor dK, - NVTETensor dV, - NVTETensor dBias, - const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, - const NVTETensor seq_offsets_q, - const NVTETensor seq_offsets_k, - const NVTETensor seq_offsets_v, - const NVTETensor seq_offsets_o, - size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - NVTETensor workspace, - cudaStream_t stream) { +void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, + const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, + NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); - const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); - const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); - const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); - const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); - const Tensor *input_Q = reinterpret_cast(Q); - const Tensor *input_K = reinterpret_cast(K); - const Tensor *input_V = reinterpret_cast(V); - const Tensor *input_O = reinterpret_cast(O); - const Tensor *input_dO = reinterpret_cast(dO); - const Tensor *input_S = reinterpret_cast(S); - Tensor *input_output_dP = reinterpret_cast(dP); - Tensor *output_dQ = reinterpret_cast(dQ); - Tensor *output_dK = reinterpret_cast(dK); - Tensor *output_dV = reinterpret_cast(dV); - Tensor *output_dBias = reinterpret_cast(dBias); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); + const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); + const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); + const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_K = reinterpret_cast(K); + const Tensor *input_V = reinterpret_cast(V); + const Tensor *input_O = reinterpret_cast(O); + const Tensor *input_dO = reinterpret_cast(dO); + const Tensor *input_S = reinterpret_cast(S); + Tensor *input_output_dP = reinterpret_cast(dP); + Tensor *output_dQ = reinterpret_cast(dQ); + Tensor *output_dK = reinterpret_cast(dK); + Tensor *output_dV = reinterpret_cast(dV); + Tensor *output_dBias = reinterpret_cast(dBias); + Tensor *wkspace = 
reinterpret_cast(workspace); auto ndim = input_Q->data.shape.size(); size_t b = input_cu_seqlens_q->data.shape[0] - 1; @@ -827,65 +663,51 @@ void nvte_fused_attn_bwd( const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend( - Q_type, KV_type, - qkv_layout, bias_type, attn_mask_type, - dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d); + nvte_get_fused_attn_backend(Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd( - b, h_q, max_seqlen_q, max_seqlen_kv, d, - attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_Q, input_K, input_V, input_dO, - output_S, - output_dQ, output_dK, output_dV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, - wkspace, stream, handle); + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_Q, input_K, input_V, input_dO, output_S, + output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, + input_cu_seqlens_kv, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - } - fused_attn_arbitrary_seqlen_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_Q, input_K, input_V, input_O, input_dO, input_Bias, - output_S, - output_dQ, output_dK, output_dV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, - input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, - input_rng_state, wkspace, stream, handle); + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *input_Bias, *input_rng_state; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + } else { + input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + } + fused_attn_arbitrary_seqlen_bwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, + output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, + input_rng_state, wkspace, stream, handle); #else const char *err_msg = - "cuDNN 8.9.0 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. \n"; + "cuDNN 8.9.0 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. 
\n"; NVTE_ERROR(err_msg); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - input_Q, input_K, input_V, input_O, input_dO, - input_M, input_ZInv, - input_S, input_output_dP, - output_dQ, output_dK, output_dV, - input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, - wkspace, stream, handle); + const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, + output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 9e34c40cdd7ebb8fd65c49bc4d29da79f1342ada..94dab77079e90fc07cbe663e0e4f6e6f7f8bc8ab 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -4,19 +4,19 @@ * See LICENSE for license information. 
************************************************************************/ -#include "fused_attn_f16_arbitrary_seqlen.h" - #include #include #include #include + #include #include #include "../common.h" -#include "utils.h" #include "../util/cuda_runtime.h" #include "../util/system.h" +#include "fused_attn_f16_arbitrary_seqlen.h" +#include "utils.h" #if (CUDNN_VERSION >= 8900) #define Q_ID 1 @@ -48,683 +48,665 @@ namespace transformer_engine { namespace fused_attn { void fused_attn_arbitrary_seqlen_fwd_impl( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, - int64_t bias_b, int64_t bias_h, - bool is_training, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, - void *devPtrSoftmaxStats, void *devPtrO, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, - void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, - void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsK, - void* devPtrSeqOffsetsV, void* devPtrSeqOffsetsO, - cudnn_frontend::DataType_t tensorType, - void *workspace, size_t *workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); - bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) - || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) - || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); - bool is_dropout = (is_training && dropout_probability != 0.0f); - bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); - if (is_ragged) { - NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b, + int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, + void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, + void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsK, + void *devPtrSeqOffsetsV, void *devPtrSeqOffsetsO, cudnn_frontend::DataType_t tensorType, + void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; + bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); + bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + bool is_dropout = (is_training && dropout_probability != 0.0f); + bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); + if (is_ragged) { + NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + } + + try { + FADescriptor_v1 descriptor{b, + h, + hg, + s_q, + s_kv, + d, + bias_b, + bias_h, + scaling_factor, + is_training, + dropout_probability, + layout, + bias_type, + mask_type, + tensorType, + tensorType}; + + namespace fe = cudnn_frontend; + using 
graph_and_tensors = + std::tuple, + std::shared_ptr, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // attn_scale + std::shared_ptr, // O + std::shared_ptr, // Stats + std::shared_ptr, // bias + std::shared_ptr, // seq_q + std::shared_ptr, // seq_kv + std::shared_ptr, // offset_q + std::shared_ptr, // offset_k + std::shared_ptr, // offset_v + std::shared_ptr, // offset_o + std::shared_ptr, // dropout_seed + std::shared_ptr>; // dropout_offset + + using CacheType = std::map; + static thread_local CacheType sdpa_f16_fprop_cache; + + // Get plan from cache if cache is available, otherwise create one + auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor) -> graph_and_tensors { + // if hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto graph = it->second; + return graph; + } + + // otherwise, build the op_graph and the plan. Then update cache + auto mha_graph = std::make_shared(); + mha_graph->set_io_data_type(tensorType) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + std::shared_ptr Q, K, V, attn_scale; + std::shared_ptr bias, seq_q, seq_kv; + std::shared_ptr offset_q, offset_k, offset_v, offset_o; + std::shared_ptr dropout_seed, dropout_offset; + + offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + + std::vector q_stride(4); + std::vector k_stride(4); + std::vector v_stride(4); + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + + if (is_ragged) { + Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride(q_stride) + .set_ragged_offset(offset_q)); + K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride) + .set_ragged_offset(offset_k)); + V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride) + .set_ragged_offset(offset_v)); + } else { + Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride(q_stride)); + K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride)); + V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride)); + } + + attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + 
fe::graph::SDPA_attributes sdpa_options; + sdpa_options = fe::graph::SDPA_attributes() + .set_name("flash_attention") + .set_is_inference(false) + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale); + + sdpa_options.set_alibi_mask(is_alibi); + + if (is_bias) { + bias = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({bias_b, bias_h, s_q, s_kv}) + .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + sdpa_options.set_bias(bias); + } + + if (is_padding) { + seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_options.set_padding_mask(is_padding).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv); + } + + if (is_dropout) { + dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); + } + + auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options); + + std::vector o_stride(4); + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_O_Matrix); + if (is_ragged) { + O->set_output(true) + .set_dim({b, h, s_q, d}) + .set_stride(o_stride) + .set_ragged_offset(offset_o); + } else { + O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); + } + + Stats->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}); + + std::tuple, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // attn_scale + std::shared_ptr> // O + key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O); + auto Stats_tuple = std::make_tuple(Stats); + auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); + auto padding_tuple = + is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); + auto offset_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto dropout_tuple = is_dropout ? 
std::make_tuple(dropout_seed, dropout_offset) + : std::make_tuple(nullptr, nullptr); + + NVTE_CHECK_CUDNN_FE(mha_graph->validate()); + NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); + NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); + NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); + NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); + + auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, + bias_tuple, padding_tuple, offset_tuple, dropout_tuple); + cache.insert({descriptor, return_tuple}); + + return return_tuple; + }; + + auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, offset_q, offset_k, + offset_v, offset_o, dropout_seed, dropout_offset] = + get_graph(sdpa_f16_fprop_cache, descriptor); + + auto plan_workspace_size = mha_graph->get_workspace_size(); + // Exit to request upper level API to allocate memory if needed + size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); + if (workspace == nullptr) { + *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; + return; } - try { - FADescriptor_v1 descriptor{b, h, - hg, s_q, - s_kv, d, - bias_b, bias_h, - scaling_factor, is_training, - dropout_probability, layout, - bias_type, mask_type, - tensorType, tensorType}; - - namespace fe = cudnn_frontend; - using graph_and_tensors = std::tuple, - std::shared_ptr, // Q - std::shared_ptr, // K - std::shared_ptr, // V - std::shared_ptr, // attn_scale - std::shared_ptr, // O - std::shared_ptr, // Stats - std::shared_ptr, // bias - std::shared_ptr, // seq_q - std::shared_ptr, // seq_kv - std::shared_ptr, // offset_q - std::shared_ptr, // offset_k - std::shared_ptr, // offset_v - std::shared_ptr, // offset_o - std::shared_ptr, // dropout_seed - std::shared_ptr >; // dropout_offset - - using CacheType = std::map; - static thread_local CacheType sdpa_f16_fprop_cache; - - // Get plan from cache if cache is available, otherwise create one - auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor) - -> graph_and_tensors { - // if hit, return - auto it = cache.find(descriptor); - if (it != cache.end()) { - auto graph = it->second; - return graph; - } - - // otherwise, build the op_graph and the plan. 
Then update cache - auto mha_graph = std::make_shared(); - mha_graph->set_io_data_type(tensorType) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - std::shared_ptr Q, K, V, attn_scale; - std::shared_ptr bias, seq_q, seq_kv; - std::shared_ptr offset_q, offset_k, offset_v, offset_o; - std::shared_ptr dropout_seed, dropout_offset; - - offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_q") - .set_dim({b+1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_k") - .set_dim({b+1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_v") - .set_dim({b+1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_o") - .set_dim({b+1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - - std::vector q_stride(4); - std::vector k_stride(4); - std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - - if (is_ragged) { - Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride) - .set_ragged_offset(offset_q)); - K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride) - .set_ragged_offset(offset_k)); - V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride) - .set_ragged_offset(offset_v)); - } else { - Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); - K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); - V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - } - - attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); - - fe::graph::SDPA_attributes sdpa_options; - sdpa_options = fe::graph::SDPA_attributes() - .set_name("flash_attention") - .set_is_inference(false) - .set_causal_mask(is_causal) - .set_attn_scale(attn_scale); - - sdpa_options.set_alibi_mask(is_alibi); - - if (is_bias) { - bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_dim({bias_b, bias_h, s_q, s_kv}) - .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); - sdpa_options.set_bias(bias); - } - - if (is_padding) { - seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("seq_q") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("seq_kv") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - sdpa_options.set_padding_mask(is_padding) - .set_seq_len_q(seq_q) - 
.set_seq_len_kv(seq_kv); - } - - if (is_dropout) { - dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT64)); - dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT64)); - sdpa_options.set_dropout( - dropout_probability, dropout_seed, dropout_offset); - } - - auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options); - - std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - if (is_ragged) { - O->set_output(true) - .set_dim({b, h, s_q, d}) - .set_stride(o_stride) - .set_ragged_offset(offset_o); - } else { - O->set_output(true) - .set_dim({b, h, s_q, d}) - .set_stride(o_stride); - } - - Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}); - - std::tuple, // Q - std::shared_ptr, // K - std::shared_ptr, // V - std::shared_ptr, // attn_scale - std::shared_ptr > // O - key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O); - auto Stats_tuple = std::make_tuple(Stats); - auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); - auto padding_tuple = is_padding ? - std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto offset_tuple = is_ragged ? - std::make_tuple(offset_q, offset_k, offset_v, offset_o) : - std::make_tuple(nullptr, nullptr, nullptr, nullptr); - auto dropout_tuple = is_dropout ? - std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); - - NVTE_CHECK_CUDNN_FE(mha_graph->validate()); - NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); - NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); - NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); - NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - - auto return_tuple = std::tuple_cat( - std::make_tuple(mha_graph), key_tensors_tuple, - Stats_tuple, bias_tuple, padding_tuple, offset_tuple, dropout_tuple); - cache.insert({descriptor, return_tuple}); - - return return_tuple; - }; - - auto [mha_graph, Q, K, V, attn_scale, O, Stats, - bias, seq_q, seq_kv, offset_q, offset_k, offset_v, offset_o, - dropout_seed, dropout_offset] = get_graph( - sdpa_f16_fprop_cache, descriptor); - - auto plan_workspace_size = mha_graph->get_workspace_size(); - // Exit to request upper level API to allocate memory if needed - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); - if (workspace == nullptr) { - *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; - return; - } + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. 
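
Note on the `workspace == nullptr` early return above: the impl follows a two-phase sizing protocol, where the caller first invokes it with a null workspace pointer to learn the required byte count (plan workspace plus the 2 * b int32 actual-seqlen scratch), allocates a buffer of that size, and then calls again to execute. The sketch below is a minimal illustration of that calling pattern only; `run_fwd`, its fixed plan size, and the assumed batch size are hypothetical stand-ins, and plain `cudaMalloc` stands in for whatever allocator the framework actually uses.

#include <cuda_runtime.h>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for fused_attn_arbitrary_seqlen_fwd_impl, keeping only
// the workspace-sizing convention shown in the diff above.
void run_fwd(void *workspace, size_t *workspace_size) {
  const size_t plan_bytes = 1 << 20;                    // stand-in for mha_graph->get_workspace_size()
  const size_t b = 8;                                   // assumed batch size
  const size_t seqlen_bytes = 2 * b * sizeof(int32_t);  // actual-seqlen scratch, as in the impl
  if (workspace == nullptr) {                           // phase 1: report the requirement and exit
    *workspace_size = plan_bytes + seqlen_bytes;
    return;
  }
  // phase 2 would set the cuDNN stream, fill the variant pack, and execute the graph
}

void launch_with_sized_workspace() {
  size_t workspace_size = 0;
  run_fwd(nullptr, &workspace_size);         // query pass: size only
  void *workspace = nullptr;
  if (workspace_size > 0) {
    cudaMalloc(&workspace, workspace_size);  // caller-owned scratch buffer
  }
  run_fwd(workspace, &workspace_size);       // execution pass
  cudaFree(workspace);
}
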
- NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - - // Build variant pack - std::unordered_map, void*> variant_pack = { - {Q, devPtrQ}, - {K, devPtrK}, - {V, devPtrV}, - {attn_scale, &scaling_factor}, - {O, devPtrO}, - {Stats, devPtrSoftmaxStats}}; - - if (is_bias) { - variant_pack[bias] = devPtrBias; - } + // Build variant pack + std::unordered_map, void *> variant_pack = { + {Q, devPtrQ}, {K, devPtrK}, + {V, devPtrV}, {attn_scale, &scaling_factor}, + {O, devPtrO}, {Stats, devPtrSoftmaxStats}}; - if (is_padding) { - constexpr size_t nthreads_per_block = 128; - const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); - cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlensQ), - static_cast(devPtrCuSeqlensKV), - static_cast(devActualSeqlenQ), - static_cast(devActualSeqlenKV)); - variant_pack[seq_q] = devActualSeqlenQ; - variant_pack[seq_kv] = devActualSeqlenKV; - } - - if (is_ragged) { - variant_pack[offset_q] = devPtrSeqOffsetsQ; - variant_pack[offset_k] = devPtrSeqOffsetsK; - variant_pack[offset_v] = devPtrSeqOffsetsV; - variant_pack[offset_o] = devPtrSeqOffsetsO; - } - - if (is_dropout) { - variant_pack[dropout_seed] = devPtrDropoutSeed; - variant_pack[dropout_offset] = devPtrDropoutOffset; - } + if (is_bias) { + variant_pack[bias] = devPtrBias; + } - NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); - } catch (cudnn_frontend::cudnnException &e) { - NVTE_ERROR(e.what()); + if (is_padding) { + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + cu_seqlens_to_actual_seqlens<<>>( + b, static_cast(devPtrCuSeqlensQ), + static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), + static_cast(devActualSeqlenKV)); + variant_pack[seq_q] = devActualSeqlenQ; + variant_pack[seq_kv] = devActualSeqlenKV; } -} -void fused_attn_arbitrary_seqlen_bwd_impl( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, - int64_t bias_b, int64_t bias_h, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - void* devPtrQ, void* devPtrKTranspose, void* devPtrVTranspose, - void* devPtrO, void* devPtrSoftmaxStats, void* devPtrBias, - void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, void* devPtrdBias, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, - void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, - void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsK, - void* devPtrSeqOffsetsV, void* devPtrSeqOffsetsO, - cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); - bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) - || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) - || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); - bool is_dropout = (dropout_probability != 0.0f); - bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); if 
(is_ragged) { - NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + variant_pack[offset_q] = devPtrSeqOffsetsQ; + variant_pack[offset_k] = devPtrSeqOffsetsK; + variant_pack[offset_v] = devPtrSeqOffsetsV; + variant_pack[offset_o] = devPtrSeqOffsetsO; } - try { - FADescriptor_v1 descriptor{b, h, - hg, s_q, - s_kv, d, - bias_b, bias_h, - scaling_factor, true, - dropout_probability, layout, - bias_type, mask_type, - tensorType, tensorType}; - - namespace fe = cudnn_frontend; - using graph_and_tensors = std::tuple, - std::shared_ptr, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // dO - std::shared_ptr, // stats - std::shared_ptr, // attn_scale - std::shared_ptr, // dQ - std::shared_ptr, // dK - std::shared_ptr, // dV - std::shared_ptr, // bias - std::shared_ptr, // dBias - std::shared_ptr, // seq_q - std::shared_ptr, // seq_kv - std::shared_ptr, // offset_q - std::shared_ptr, // offset_k - std::shared_ptr, // offset_v - std::shared_ptr, // offset_o - std::shared_ptr, // dropout_seed - std::shared_ptr >; // dropout_offset - - using CacheType = std::map; - static thread_local CacheType sdpa_f16_bprop_cache; - - // Get plan from cache if cache is available, otherwise create one - auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor) - -> graph_and_tensors { - // if hit, return - auto it = cache.find(descriptor); - if (it != cache.end()) { - auto graph = it->second; - return graph; - } - - // otherwise, build the op_graph and the plan. Then update cache - auto mha_graph = std::make_shared(); - mha_graph->set_io_data_type(tensorType) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - std::shared_ptr q, k, v, o, dO, stats, attn_scale; - std::shared_ptr bias, dBias, seq_q, seq_kv; - std::shared_ptr offset_q, offset_k, offset_v, offset_o; - std::shared_ptr dropout_seed, dropout_offset; - - offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_q") - .set_dim({b+1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_k") - .set_dim({b+1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_v") - .set_dim({b+1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_o") - .set_dim({b+1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - std::vector q_stride(4); - std::vector k_stride(4); - std::vector v_stride(4); - std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - - if (is_ragged) { - q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride) - .set_ragged_offset(offset_q)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride) - 
.set_ragged_offset(offset_k)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride) - .set_ragged_offset(offset_v)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("O") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride) - .set_ragged_offset(offset_o)); - dO = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride) - .set_ragged_offset(offset_o)); - } else { - q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("O") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); - dO = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); - } - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - - attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); - - fe::graph::SDPA_backward_attributes sdpa_backward_options; - sdpa_backward_options = fe::graph::SDPA_backward_attributes() - .set_name("flash_attention_backward") - .set_causal_mask(is_causal) - .set_attn_scale(attn_scale); - - sdpa_backward_options.set_alibi_mask(is_alibi); - - if (is_bias) { - bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_dim({bias_b, bias_h, s_q, s_kv}) - .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); - dBias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dBias") - .set_dim({bias_b, bias_h, s_q, s_kv}) - .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); - sdpa_backward_options.set_bias(bias); - // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] - // are not supported for dbias calculation but they are - // supported for forward bias calculation - if ((bias_b == 1) && (bias_h == h)) { - sdpa_backward_options.set_dbias(dBias); - } - } - - if (is_padding) { - seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("seq_q") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("seq_kv") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - sdpa_backward_options.set_padding_mask(is_padding) - .set_seq_len_q(seq_q) - .set_seq_len_kv(seq_kv); - } - - if (is_dropout) { - dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT64)); - dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT64)); - sdpa_backward_options.set_dropout( - dropout_probability, dropout_seed, dropout_offset); - } - - auto [dQ, dK, dV] = mha_graph->sdpa_backward( - q, k, v, o, dO, stats, 
sdpa_backward_options); - - if (is_ragged) { - dQ->set_output(true) - .set_dim({b, h, s_q, d}) - .set_stride(q_stride) - .set_ragged_offset(offset_q); - dK->set_output(true) - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride) - .set_ragged_offset(offset_k); - dV->set_output(true) - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride) - .set_ragged_offset(offset_v); - } else { - dQ->set_output(true) - .set_dim({b, h, s_q, d}) - .set_stride(q_stride); - dK->set_output(true) - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride); - dV->set_output(true) - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride); - } - - std::tuple, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // dO - std::shared_ptr, // stats - std::shared_ptr, // attn_scale - std::shared_ptr, // dQ - std::shared_ptr, // dK - std::shared_ptr > // dV - key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV); - auto bias_tuple = is_bias ? - std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); - auto padding_tuple = is_padding ? - std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto offset_tuple = is_ragged ? - std::make_tuple(offset_q, offset_k, offset_v, offset_o) : - std::make_tuple(nullptr, nullptr, nullptr, nullptr); - auto dropout_tuple = is_dropout ? - std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); - - NVTE_CHECK_CUDNN_FE(mha_graph->validate()); - NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); - NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); - NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); - NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - - auto return_tuple = std::tuple_cat( - std::make_tuple(mha_graph), key_tensors_tuple, - bias_tuple, padding_tuple, offset_tuple, dropout_tuple); - cache.insert({descriptor, return_tuple}); - - return return_tuple; - }; - - auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, - bias, dBias, seq_q, seq_kv, offset_q, offset_k, offset_v, offset_o, - dropout_seed, dropout_offset] = get_graph( - sdpa_f16_bprop_cache, descriptor); - - auto plan_workspace_size = mha_graph->get_workspace_size(); - - // Exit to request upper level API to allocate memory if needed - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); - if (workspace == nullptr) { - *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; - return; - } + if (is_dropout) { + variant_pack[dropout_seed] = devPtrDropoutSeed; + variant_pack[dropout_offset] = devPtrDropoutOffset; + } - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. 
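
As in the forward path, the backward graph is built once per unique `FADescriptor_v1` and reused from a `thread_local` map (`sdpa_f16_bprop_cache`), so repeated calls with identical problem shapes skip validate/build_operation_graph/build_plans. A minimal sketch of that memoization pattern follows; `Descriptor` and `Graph` are simplified stand-ins (the real key carries layout, bias, mask, dropout and dtype fields, and the real payload is the graph plus all tensor handles).

#include <cstdint>
#include <map>
#include <memory>
#include <tuple>

// Simplified stand-in for FADescriptor_v1: only a few shape fields.
struct Descriptor {
  int64_t b, h, s_q, s_kv, d;
  bool operator<(const Descriptor &o) const {
    return std::tie(b, h, s_q, s_kv, d) < std::tie(o.b, o.h, o.s_q, o.s_kv, o.d);
  }
};

// Stand-in for cudnn_frontend::graph::Graph.
struct Graph {
  void build() { /* validate / build_operation_graph / build_plans */ }
};

std::shared_ptr<Graph> get_graph(const Descriptor &desc) {
  // One cache per thread, mirroring sdpa_f16_fprop_cache / sdpa_f16_bprop_cache.
  static thread_local std::map<Descriptor, std::shared_ptr<Graph>> cache;
  auto it = cache.find(desc);
  if (it != cache.end()) {
    return it->second;   // cache hit: reuse the already-built plan
  }
  auto graph = std::make_shared<Graph>();
  graph->build();        // cache miss: build once, then memoize
  cache.emplace(desc, graph);
  return graph;
}
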
- NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - - // build variant pack - std::unordered_map, void*> variant_pack = { - {q, devPtrQ}, - {k, devPtrKTranspose}, - {v, devPtrVTranspose}, - {o, devPtrO}, - {dO, devPtrdO}, - {stats, devPtrSoftmaxStats}, - {attn_scale, &scaling_factor}, - {dQ, devPtrdQ}, - {dK, devPtrdK}, - {dV, devPtrdV}, - }; - - if (is_bias) { - variant_pack[bias] = devPtrBias; - if ((bias_b == 1) && (bias_h == h)) { - variant_pack[dBias] = devPtrdBias; - } else { - variant_pack[dBias] = nullptr; - } - } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); + } catch (cudnn_frontend::cudnnException &e) { + NVTE_ERROR(e.what()); + } +} - if (is_padding) { - constexpr size_t nthreads_per_block = 128; - const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); - cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlensQ), - static_cast(devPtrCuSeqlensKV), - static_cast(devActualSeqlenQ), - static_cast(devActualSeqlenKV)); - variant_pack[seq_q] = devActualSeqlenQ; - variant_pack[seq_kv] = devActualSeqlenKV; +void fused_attn_arbitrary_seqlen_bwd_impl( + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b, + int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrKTranspose, + void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, + void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, + void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsK, + void *devPtrSeqOffsetsV, void *devPtrSeqOffsetsO, cudnn_frontend::DataType_t tensorType, + void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; + bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); + bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + bool is_dropout = (dropout_probability != 0.0f); + bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); + if (is_ragged) { + NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + } + + try { + FADescriptor_v1 descriptor{b, + h, + hg, + s_q, + s_kv, + d, + bias_b, + bias_h, + scaling_factor, + true, + dropout_probability, + layout, + bias_type, + mask_type, + tensorType, + tensorType}; + + namespace fe = cudnn_frontend; + using graph_and_tensors = + std::tuple, + std::shared_ptr, // q + std::shared_ptr, // k + std::shared_ptr, // v + std::shared_ptr, // o + std::shared_ptr, // dO + std::shared_ptr, // stats + std::shared_ptr, // attn_scale + std::shared_ptr, // dQ + std::shared_ptr, // dK + std::shared_ptr, // dV + std::shared_ptr, // bias + std::shared_ptr, // dBias + std::shared_ptr, // seq_q + std::shared_ptr, // seq_kv + std::shared_ptr, // offset_q + std::shared_ptr, // offset_k + std::shared_ptr, // offset_v + std::shared_ptr, // offset_o + std::shared_ptr, // dropout_seed + 
std::shared_ptr>; // dropout_offset + + using CacheType = std::map; + static thread_local CacheType sdpa_f16_bprop_cache; + + // Get plan from cache if cache is available, otherwise create one + auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor) -> graph_and_tensors { + // if hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto graph = it->second; + return graph; + } + + // otherwise, build the op_graph and the plan. Then update cache + auto mha_graph = std::make_shared(); + mha_graph->set_io_data_type(tensorType) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + std::shared_ptr q, k, v, o, dO, stats, attn_scale; + std::shared_ptr bias, dBias, seq_q, seq_kv; + std::shared_ptr offset_q, offset_k, offset_v, offset_o; + std::shared_ptr dropout_seed, dropout_offset; + + offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + std::vector q_stride(4); + std::vector k_stride(4); + std::vector v_stride(4); + std::vector o_stride(4); + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_O_Matrix); + + if (is_ragged) { + q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride(q_stride) + .set_ragged_offset(offset_q)); + k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride) + .set_ragged_offset(offset_k)); + v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride) + .set_ragged_offset(offset_v)); + o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim({b, h, s_q, d}) + .set_stride(o_stride) + .set_ragged_offset(offset_o)); + dO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO") + .set_dim({b, h, s_q, d}) + .set_stride(o_stride) + .set_ragged_offset(offset_o)); + } else { + q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride(q_stride)); + k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride)); + v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride)); + o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim({b, h, s_q, d}) + .set_stride(o_stride)); + dO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO") + 
.set_dim({b, h, s_q, d}) + .set_stride(o_stride)); + } + stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + + attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + fe::graph::SDPA_backward_attributes sdpa_backward_options; + sdpa_backward_options = fe::graph::SDPA_backward_attributes() + .set_name("flash_attention_backward") + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale); + + sdpa_backward_options.set_alibi_mask(is_alibi); + + if (is_bias) { + bias = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({bias_b, bias_h, s_q, s_kv}) + .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + dBias = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dBias") + .set_dim({bias_b, bias_h, s_q, s_kv}) + .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + sdpa_backward_options.set_bias(bias); + // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] + // are not supported for dbias calculation but they are + // supported for forward bias calculation + if ((bias_b == 1) && (bias_h == h)) { + sdpa_backward_options.set_dbias(dBias); } + } + + if (is_padding) { + seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_backward_options.set_padding_mask(is_padding) + .set_seq_len_q(seq_q) + .set_seq_len_kv(seq_kv); + } + + if (is_dropout) { + dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); + } + + auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options); + + if (is_ragged) { + dQ->set_output(true) + .set_dim({b, h, s_q, d}) + .set_stride(q_stride) + .set_ragged_offset(offset_q); + dK->set_output(true) + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride) + .set_ragged_offset(offset_k); + dV->set_output(true) + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride) + .set_ragged_offset(offset_v); + } else { + dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); + dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); + dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); + } + + std::tuple, // q + std::shared_ptr, // k + std::shared_ptr, // v + std::shared_ptr, // o + std::shared_ptr, // dO + std::shared_ptr, // stats + std::shared_ptr, // attn_scale + std::shared_ptr, // dQ + std::shared_ptr, // dK + std::shared_ptr> // dV + key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV); + auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); + auto padding_tuple = + is_padding ? 
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); + auto offset_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) + : std::make_tuple(nullptr, nullptr); + + NVTE_CHECK_CUDNN_FE(mha_graph->validate()); + NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); + NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); + NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); + NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); + + auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, + padding_tuple, offset_tuple, dropout_tuple); + cache.insert({descriptor, return_tuple}); + + return return_tuple; + }; + + auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv, + offset_q, offset_k, offset_v, offset_o, dropout_seed, dropout_offset] = + get_graph(sdpa_f16_bprop_cache, descriptor); + + auto plan_workspace_size = mha_graph->get_workspace_size(); + + // Exit to request upper level API to allocate memory if needed + size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); + if (workspace == nullptr) { + *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; + return; + } - if (is_ragged) { - variant_pack[offset_q] = devPtrSeqOffsetsQ; - variant_pack[offset_k] = devPtrSeqOffsetsK; - variant_pack[offset_v] = devPtrSeqOffsetsV; - variant_pack[offset_o] = devPtrSeqOffsetsO; - } + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + + // build variant pack + std::unordered_map, void *> variant_pack = { + {q, devPtrQ}, + {k, devPtrKTranspose}, + {v, devPtrVTranspose}, + {o, devPtrO}, + {dO, devPtrdO}, + {stats, devPtrSoftmaxStats}, + {attn_scale, &scaling_factor}, + {dQ, devPtrdQ}, + {dK, devPtrdK}, + {dV, devPtrdV}, + }; + + if (is_bias) { + variant_pack[bias] = devPtrBias; + if ((bias_b == 1) && (bias_h == h)) { + variant_pack[dBias] = devPtrdBias; + } else { + variant_pack[dBias] = nullptr; + } + } - if (is_dropout) { - variant_pack[dropout_seed] = devPtrDropoutSeed; - variant_pack[dropout_offset] = devPtrDropoutOffset; - } + if (is_padding) { + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + cu_seqlens_to_actual_seqlens<<>>( + b, static_cast(devPtrCuSeqlensQ), + static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), + static_cast(devActualSeqlenKV)); + variant_pack[seq_q] = devActualSeqlenQ; + variant_pack[seq_kv] = devActualSeqlenKV; + } + + if (is_ragged) { + variant_pack[offset_q] = devPtrSeqOffsetsQ; + variant_pack[offset_k] = devPtrSeqOffsetsK; + variant_pack[offset_v] = devPtrSeqOffsetsV; + variant_pack[offset_o] = devPtrSeqOffsetsO; + } - NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); - } catch (cudnn_frontend::cudnnException &e) { - NVTE_ERROR(e.what()); + if (is_dropout) { + variant_pack[dropout_seed] = devPtrDropoutSeed; + variant_pack[dropout_offset] = devPtrDropoutOffset; } + + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); + } catch (cudnn_frontend::cudnnException &e) { + 
NVTE_ERROR(e.what()); + } } } // namespace fused_attn @@ -736,595 +718,554 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = typeToSize(QKV_type) * num_attn_heads * head_dim; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = typeToSize(QKV_type) * head_dim; - } - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; - void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; - void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; - void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; - - if (Aux_CTX_Tensors->size == 0) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; - output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + using namespace transformer_engine; + 
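
For context on the qkvpacked entry point below: Q, K and V base pointers are derived from one interleaved QKV buffer, with the slice stride depending on the layout group (`num_attn_heads * head_dim` elements for NVTE_3HD, `head_dim` elements for NVTE_H3D), and K/V sitting one and two strides past the base. A small standalone sketch of that offset arithmetic, with hypothetical names (`LayoutGroup`, `split_packed_qkv`) standing in for the NVTE enums and an explicit element size standing in for `typeToSize(QKV_type)`:

#include <cstddef>
#include <cstdint>

enum class LayoutGroup { k3HD, kH3D };  // stand-ins for NVTE_3HD / NVTE_H3D

struct QKVPointers {
  void *q, *k, *v;
};

// Mirrors the stride logic in the function below: K and V are offset from the
// packed base pointer by one and two matrix strides respectively.
QKVPointers split_packed_qkv(void *qkv, LayoutGroup group, size_t num_heads,
                             size_t head_dim, size_t elem_size) {
  const size_t stride = (group == LayoutGroup::k3HD)
                            ? elem_size * num_heads * head_dim
                            : elem_size * head_dim;
  auto *base = static_cast<int8_t *>(qkv);
  return {base, base + stride, base + 2 * stride};
}

For example, `split_packed_qkv(buf, LayoutGroup::k3HD, 16, 64, 2)` reproduces a bf16 case with 16 heads of dimension 64, placing K and V 2048 and 4096 bytes past the packed base.
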
+ const auto QKV_type = input_QKV->data.dtype; + void *devPtrQKV = input_QKV->data.dptr; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + stride = typeToSize(QKV_type) * head_dim; + } + void *devPtrQ = static_cast(devPtrQKV); + void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); + void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); + + void *devPtrBias = nullptr; + size_t bias_b = 0; + size_t bias_h = 0; + if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { + devPtrBias = input_Bias->data.dptr; + bias_b = input_Bias->data.shape[0]; + bias_h = input_Bias->data.shape[1]; + } + void *devPtrO = output_O->data.dptr; + void *devPtrS = nullptr; + void *devPtrCuSeqlens = cu_seqlens->data.dptr; + void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; + void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; + void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; + void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; + + if (Aux_CTX_Tensors->size == 0) { + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Aux_CTX_Tensors->size = 3; + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + output_S->data.dptr = nullptr; + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + output_bias->data.dptr = nullptr; + output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; + output_bias->data.dtype = QKV_type; } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + Aux_CTX_Tensors->size = 2; + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + output_S->data.dptr = nullptr; + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; } - - void* devPtrDropoutSeed = rng_state->data.dptr; - void* devPtrDropoutOffset = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_attn_heads, - max_seqlen, max_seqlen, head_dim, bias_b, bias_h, - is_training, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, - devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlens, devPtrCuSeqlens, - devPtrSeqOffsetsQ, devPtrSeqOffsetsK, - devPtrSeqOffsetsV, devPtrSeqOffsetsO, - get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, - stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); + } else if (Aux_CTX_Tensors->size == 2) { + 
Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + devPtrS = output_S->data.dptr; + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = rng_state->data.dptr; + } else if (Aux_CTX_Tensors->size == 3) { + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + devPtrS = output_S->data.dptr; + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = rng_state->data.dptr; + Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + output_bias->data.dptr = devPtrBias; + } else { + NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + } + + void *devPtrDropoutSeed = rng_state->data.dptr; + void *devPtrDropoutOffset = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + fused_attn_arbitrary_seqlen_fwd_impl( + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h, + is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, + devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, + devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } } -void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_heads, - size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, - const Tensor *cu_seqlens, const Tensor *seq_offsets_q, - const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, - const Tensor *seq_offsets_o, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = typeToSize(QKV_type) * num_attn_heads * head_dim; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = typeToSize(QKV_type) * head_dim; - } - void *devPtrQ = devPtrQKV; - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void* devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = devPtrdQKV; - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + 
stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; - void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; - void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; - void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; - - void* devPtrDropoutSeed = rng_state->data.dptr; - void* devPtrDropoutOffset = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_attn_heads, - max_seqlen, max_seqlen, head_dim, bias_b, bias_h, - attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, - devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, - devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlens, devPtrCuSeqlens, - devPtrSeqOffsetsQ, devPtrSeqOffsetsK, - devPtrSeqOffsetsV, devPtrSeqOffsetsO, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); +void fused_attn_arbitrary_seqlen_bwd_qkvpacked( + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, + const Tensor *cu_seqlens, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, + const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; + + const auto QKV_type = input_QKV->data.dtype; + void *devPtrQKV = input_QKV->data.dptr; + + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + stride = typeToSize(QKV_type) * head_dim; + } + void *devPtrQ = devPtrQKV; + void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); + void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); + + void *devPtrO = input_O->data.dptr; + void *devPtrdO = input_dO->data.dptr; + void *devPtrBias = nullptr; + void *devPtrdBias = nullptr; + size_t bias_b = 0; + size_t bias_h = 0; + if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { + devPtrBias = input_Bias->data.dptr; + devPtrdBias = output_dBias->data.dptr; + bias_b = output_dBias->data.shape[0]; + bias_h = output_dBias->data.shape[1]; + } + + void *devPtrdQKV = output_dQKV->data.dptr; + void *devPtrdQ = devPtrdQKV; + void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); + void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); + + void *devPtrSoftmaxStats = nullptr; + devPtrSoftmaxStats = output_S->data.dptr; + + void *devPtrCuSeqlens = 
cu_seqlens->data.dptr; + void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; + void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; + void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; + void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; + + void *devPtrDropoutSeed = rng_state->data.dptr; + void *devPtrDropoutOffset = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + fused_attn_arbitrary_seqlen_bwd_impl( + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO, + devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsetsQ, + devPtrSeqOffsetsK, devPtrSeqOffsetsV, devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } } void fused_attn_arbitrary_seqlen_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, - const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = typeToSize(QKV_type) * head_dim; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; - void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; - void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; - void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; - - if (Aux_CTX_Tensors->size == 0) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = 
{batch, num_attn_heads, max_seqlen_q, 1}; - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; - output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, + const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + + const auto QKV_type = input_Q->data.dtype; + void *devPtrQ = input_Q->data.dptr; + void *devPtrKV = input_KV->data.dptr; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + stride = typeToSize(QKV_type) * head_dim; + } + void *devPtrK = devPtrKV; + void *devPtrV = static_cast(static_cast(devPtrKV) + stride); + + void *devPtrBias = nullptr; + size_t bias_b = 0; + size_t bias_h = 0; + if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { + devPtrBias = input_Bias->data.dptr; + bias_b = input_Bias->data.shape[0]; + bias_h = input_Bias->data.shape[1]; + } + void *devPtrO = output_O->data.dptr; + void *devPtrS = nullptr; + + void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; + void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; + void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; + void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; + void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; + void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; + + if 
(Aux_CTX_Tensors->size == 0) { + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Aux_CTX_Tensors->size = 3; + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + output_S->data.dptr = nullptr; + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + output_bias->data.dptr = nullptr; + output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; + output_bias->data.dtype = QKV_type; } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + Aux_CTX_Tensors->size = 2; + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + output_S->data.dptr = nullptr; + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; } - - void* devPtrDropoutSeed = rng_state->data.dptr; - void* devPtrDropoutOffset = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - is_training, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, - devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsK, - devPtrSeqOffsetsV, devPtrSeqOffsetsO, - get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, - stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); + } else if (Aux_CTX_Tensors->size == 2) { + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + devPtrS = output_S->data.dptr; + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = rng_state->data.dptr; + } else if (Aux_CTX_Tensors->size == 3) { + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + devPtrS = output_S->data.dptr; + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = rng_state->data.dptr; + Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + output_bias->data.dptr = devPtrBias; + } else { + NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + } + + void *devPtrDropoutSeed = rng_state->data.dptr; + void *devPtrDropoutOffset = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + fused_attn_arbitrary_seqlen_fwd_impl( + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, + is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, + devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, 
devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, + devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } } void fused_attn_arbitrary_seqlen_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dKV, - Tensor *output_dBias, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, - const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, - const Tensor *seq_offsets_o, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = typeToSize(QKV_type) * head_dim; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void* devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdKV = output_dKV->data.dptr; - void *devPtrdK = devPtrdKV; - void *devPtrdV = static_cast(static_cast(devPtrdKV) + stride); - - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; - void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; - void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; - void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; - - void* devPtrDropoutSeed = rng_state->data.dptr; - void* devPtrDropoutOffset = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, - devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, - devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsK, - devPtrSeqOffsetsV, 
devPtrSeqOffsetsO, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, + Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, + const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + + const auto QKV_type = input_Q->data.dtype; + void *devPtrQ = input_Q->data.dptr; + void *devPtrKV = input_KV->data.dptr; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + stride = typeToSize(QKV_type) * head_dim; + } + void *devPtrK = devPtrKV; + void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride); + + void *devPtrO = input_O->data.dptr; + void *devPtrdO = input_dO->data.dptr; + void *devPtrBias = nullptr; + void *devPtrdBias = nullptr; + size_t bias_b = 0; + size_t bias_h = 0; + if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { + devPtrBias = input_Bias->data.dptr; + devPtrdBias = output_dBias->data.dptr; + bias_b = output_dBias->data.shape[0]; + bias_h = output_dBias->data.shape[1]; + } + + void *devPtrdQ = output_dQ->data.dptr; + void *devPtrdKV = output_dKV->data.dptr; + void *devPtrdK = devPtrdKV; + void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdKV) + stride); + + void *devPtrSoftmaxStats = nullptr; + devPtrSoftmaxStats = output_S->data.dptr; + + void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; + void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; + void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; + void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; + void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; + void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; + + void *devPtrDropoutSeed = rng_state->data.dptr; + void *devPtrDropoutOffset = + reinterpret_cast<void *>(reinterpret_cast<uint64_t *>(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + fused_attn_arbitrary_seqlen_bwd_impl( + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO, + devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, devPtrSeqOffsetsO, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + + if (workspace_size >
0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } } void fused_attn_arbitrary_seqlen_fwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, - const Tensor *seq_offsets_o, const Tensor *rng_state, + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, + const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_K->data.dptr; - void *devPtrV = input_V->data.dptr; - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; - void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; - void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; - void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; - - if (Aux_CTX_Tensors->size == 0) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; - output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_S->data.dtype = DType::kFloat32; - Tensor 
*output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + using namespace transformer_engine; + + const auto QKV_type = input_Q->data.dtype; + void *devPtrQ = input_Q->data.dptr; + void *devPtrK = input_K->data.dptr; + void *devPtrV = input_V->data.dptr; + void *devPtrO = output_O->data.dptr; + void *devPtrS = nullptr; + void *devPtrBias = nullptr; + size_t bias_b = 0; + size_t bias_h = 0; + if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { + devPtrBias = input_Bias->data.dptr; + bias_b = input_Bias->data.shape[0]; + bias_h = input_Bias->data.shape[1]; + } + + void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; + void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; + void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; + void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; + void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; + void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; + + if (Aux_CTX_Tensors->size == 0) { + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Aux_CTX_Tensors->size = 3; + Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); + output_S->data.dptr = nullptr; + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]); + output_bias->data.dptr = nullptr; + output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; + output_bias->data.dtype = QKV_type; } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + Aux_CTX_Tensors->size = 2; + Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); + output_S->data.dptr = nullptr; + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; } - - void* devPtrDropoutSeed = rng_state->data.dptr; - void* devPtrDropoutOffset = reinterpret_cast<void *>( - reinterpret_cast<uint64_t *>(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - is_training, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, - devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsK, - devPtrSeqOffsetsV, devPtrSeqOffsetsO, - get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, - stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); + } else if (Aux_CTX_Tensors->size == 2) { + Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); + devPtrS = output_S->data.dptr; + Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = rng_state->data.dptr; + } else if (Aux_CTX_Tensors->size == 3) { + Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); + devPtrS = output_S->data.dptr; + Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = rng_state->data.dptr; + Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]); + output_bias->data.dptr = devPtrBias; + } else { + NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + } + + void *devPtrDropoutSeed = rng_state->data.dptr; + void *devPtrDropoutOffset = + reinterpret_cast<void *>(reinterpret_cast<uint64_t *>(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + fused_attn_arbitrary_seqlen_fwd_impl( + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, + is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, + devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, + devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } } -void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, - Tensor *output_dBias, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, - const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, - const Tensor *seq_offsets_o, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_K->data.dptr; - void *devPtrV = input_V->data.dptr; - void* devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias =
input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdK = output_dK->data.dptr; - void *devPtrdV = output_dV->data.dptr; - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; - void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; - void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; - void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; - - void* devPtrDropoutSeed = rng_state->data.dptr; - void* devPtrDropoutOffset = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, - devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsK, - devPtrSeqOffsetsV, devPtrSeqOffsetsO, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); +void fused_attn_arbitrary_seqlen_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, + Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, + const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; + const auto QKV_type = input_Q->data.dtype; + void *devPtrQ = input_Q->data.dptr; + void *devPtrK = input_K->data.dptr; + void *devPtrV = input_V->data.dptr; + void *devPtrO = input_O->data.dptr; + void *devPtrdO = input_dO->data.dptr; + void *devPtrBias = nullptr; + void *devPtrdBias = nullptr; + size_t bias_b = 0; + size_t bias_h = 0; + if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { + devPtrBias = input_Bias->data.dptr; + devPtrdBias = output_dBias->data.dptr; + bias_b = output_dBias->data.shape[0]; + bias_h = output_dBias->data.shape[1]; + } + + void *devPtrdQ = output_dQ->data.dptr; + void *devPtrdK = output_dK->data.dptr; + void *devPtrdV = output_dV->data.dptr; + void *devPtrSoftmaxStats = nullptr; + devPtrSoftmaxStats = output_S->data.dptr; + + void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; + void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; + void *devPtrSeqOffsetsQ = 
seq_offsets_q->data.dptr; + void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; + void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; + void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; + + void *devPtrDropoutSeed = rng_state->data.dptr; + void *devPtrDropoutOffset = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + fused_attn_arbitrary_seqlen_bwd_impl( + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO, + devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, devPtrSeqOffsetsO, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } } } // namespace transformer_engine #endif // CUDNN_VERSION >= 8900 diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 5959e830477c26d3f9b24f3dcea6aab96722b6f2..2a1b271db1bff1ece7828c6899643cdebd662ac2 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -11,91 +11,71 @@ #ifndef TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_ #define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_ -#include "transformer_engine/fused_attn.h" #include + #include "common/common.h" +#include "transformer_engine/fused_attn.h" namespace transformer_engine { #if (CUDNN_VERSION >= 8900) void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens, const Tensor *seq_offsets_q, - const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, - const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *seq_offsets_q, + const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_O, 
const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, - Tensor *output_dQKV, Tensor *output_dBias, - const Tensor *cu_seqlens, const Tensor *seq_offsets_q, - const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, - const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, + const Tensor *cu_seqlens, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, + const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, - const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, + const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, + cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, - const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, + Tensor *output_dBias, const Tensor *cu_seqlens_q, const 
Tensor *cu_seqlens_kv, + const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, + const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, + cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, - const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, + const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, - Tensor *output_dV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, - const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, + Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, + const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index f6aa9110864cc243dc4cd8b8b48b9d894a1c5c03..88c1490c01a1146adfc1462e53b473f588f43baa 100644 --- 
a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -4,15 +4,15 @@ * See LICENSE for license information. ************************************************************************/ -#include "fused_attn_f16_max512_seqlen.h" - #include #include #include + #include #include #include "../common.h" +#include "fused_attn_f16_max512_seqlen.h" #include "utils.h" #if (CUDNN_VERSION >= 8901) @@ -44,119 +44,117 @@ namespace fused_attn { static void createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, cudnnDataType_t tensorType, std::vector &ops) { - // scale - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; + // scale + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; - int64_t k_dim[4] = {b, h, d, s_kv}; - int64_t k_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); + int64_t k_dim[4] = {b, h, d, s_kv}; + int64_t k_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, + NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - auto scaleTensor = - tensor_create(tensorType, S_CONST_ID, scale_dim, scale_stride, false, true); // is by value - auto kTensor = tensor_create(tensorType, K_ID, k_dim, k_stride, false, false); - auto afterScaleKTensor = - tensor_create(tensorType, VIRTUAL_ID, k_dim, k_stride, true, false); // is virtual + auto scaleTensor = + tensor_create(tensorType, S_CONST_ID, scale_dim, scale_stride, false, true); // is by value + auto kTensor = tensor_create(tensorType, K_ID, k_dim, k_stride, false, false); + auto afterScaleKTensor = + tensor_create(tensorType, VIRTUAL_ID, k_dim, k_stride, true, false); // is virtual - // Define the scale descriptor - auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + // Define the scale descriptor + auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - // Create a Scale Node. - auto scale_op = binary_pw_op_create(kTensor, scaleTensor, afterScaleKTensor, scaleDesc); + // Create a Scale Node. 
+ auto scale_op = binary_pw_op_create(kTensor, scaleTensor, afterScaleKTensor, scaleDesc); - ops.push_back(std::move(scale_op)); + ops.push_back(std::move(scale_op)); } static cudnn_frontend::Tensor createBMM1(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, cudnnDataType_t tensorType, - bool zero_s, - std::vector &ops) { - // Creates the necessary tensor descriptors - int64_t q_dim[4] = {b, h, s_q, d}; - int64_t q_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - - int64_t k_dim[4] = {b, h, d, s_kv}; - int64_t k_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - - int64_t p_dim[4] = {b, h, s_q, s_kv}; - int64_t p_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - - auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false); - auto afterScaleKTensor = - tensor_create(tensorType, VIRTUAL_ID, k_dim, k_stride, true, false); // is virtual - // first GEMM output - auto pTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, p_dim, p_stride, true, - false); // is virtual - - auto seqlenQTensor = - tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto seqlenKTensor = - tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - - // Define the matmul 1 desc - // set padding value optionally to 0 for writing zeros to S tensor (if not set, old behaviour) - auto matmul_1_Desc = - cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build(); - - if (zero_s) { - matmul_1_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - } - - // Create a matmul 1 Node - auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(qTensor) - .setbMatDesc(afterScaleKTensor) - .setcMatDesc(pTensor) - .setmOverrideDesc(seqlenQTensor) - .setnOverrideDesc(seqlenKTensor) - .setmatmulDesc(matmul_1_Desc) - .build(); + bool zero_s, std::vector &ops) { + // Creates the necessary tensor descriptors + int64_t q_dim[4] = {b, h, s_q, d}; + int64_t q_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); + + int64_t k_dim[4] = {b, h, d, s_kv}; + int64_t k_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, + NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); + + int64_t p_dim[4] = {b, h, s_q, s_kv}; + int64_t p_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + + auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false); + auto afterScaleKTensor = + tensor_create(tensorType, VIRTUAL_ID, k_dim, k_stride, true, false); // is virtual + // first GEMM output + auto pTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, p_dim, p_stride, true, + false); // is virtual + + auto seqlenQTensor = + tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + auto seqlenKTensor = + tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + + // Define the matmul 1 desc + // set padding value optionally to 0 for writing zeros to S tensor (if not set, 
old behaviour) + auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build(); + + if (zero_s) { + matmul_1_Desc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0.0f) + .build(); + } + + // Create a matmul 1 Node + auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(qTensor) + .setbMatDesc(afterScaleKTensor) + .setcMatDesc(pTensor) + .setmOverrideDesc(seqlenQTensor) + .setnOverrideDesc(seqlenKTensor) + .setmatmulDesc(matmul_1_Desc) + .build(); - ops.push_back(std::move(matmul_op1)); + ops.push_back(std::move(matmul_op1)); - return pTensor; + return pTensor; } static cudnn_frontend::Tensor createBias(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, cudnnDataType_t tensorType, std::vector &ops, cudnn_frontend::Tensor const &prevBlockOutputTensor) { - NVTE_CHECK(ops.size() != 0, "Bias op constructed incorrectly as the first one."); + NVTE_CHECK(ops.size() != 0, "Bias op constructed incorrectly as the first one."); - int64_t b_dim[4] = {1, h, s_q, s_kv}; - int64_t b_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + int64_t b_dim[4] = {1, h, s_q, s_kv}; + int64_t b_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - int64_t afterBias_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBias_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, afterBias_stride, layout, - NVTE_QKV_Matrix::NVTE_S_Matrix); + int64_t afterBias_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBias_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, afterBias_stride, layout, + NVTE_QKV_Matrix::NVTE_S_Matrix); - // bias - auto bTensor = tensor_create(tensorType, B_ID, b_dim, b_stride, false, false); - // output - auto afterBiasTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 50, afterBias_dim, - afterBias_stride, true, false); // is virtual + // bias + auto bTensor = tensor_create(tensorType, B_ID, b_dim, b_stride, false, false); + // output + auto afterBiasTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 50, afterBias_dim, + afterBias_stride, true, false); // is virtual - // Define the bias descriptor - auto biasDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ADD); + // Define the bias descriptor + auto biasDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ADD); - // Create a Bias Node. - auto bias_op = binary_pw_op_create(prevBlockOutputTensor, bTensor, afterBiasTensor, biasDesc); + // Create a Bias Node. 
+ auto bias_op = binary_pw_op_create(prevBlockOutputTensor, bTensor, afterBiasTensor, biasDesc); - ops.push_back(std::move(bias_op)); + ops.push_back(std::move(bias_op)); - return afterBiasTensor; + return afterBiasTensor; } static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, @@ -165,160 +163,157 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6 std::vector &ops, cudnn_frontend::Tensor const &prevBlockOutputTensor, bool is_bprop) { - NVTE_CHECK(ops.size() != 0, "Padding mask constructed incorrectly as the first one."); - - // subtraction output - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - - int64_t maskVal_dim[4] = {1, 1, 1, 1}; - int64_t maskVal_stride[4] = {1, 1, 1, 1}; - - // mask value to put in the masked pixels - auto maskValTensor = tensor_create(CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride, - false, true); // is by value - - auto seqlenQTensor = - tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto seqlenKTensor = - tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - // gen index row output - auto rowIndexTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 100, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // gen index column output - auto columnIndexTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 101, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // less than row output - auto lessThanRowTensor = - tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 102, afterBMM1_dim, afterBMM1_stride, true, - false); // is virtual - // less than column output - auto lessThanColTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 103, afterBMM1_dim, + NVTE_CHECK(ops.size() != 0, "Padding mask constructed incorrectly as the first one."); + + // subtraction output + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + + int64_t maskVal_dim[4] = {1, 1, 1, 1}; + int64_t maskVal_stride[4] = {1, 1, 1, 1}; + + // mask value to put in the masked pixels + auto maskValTensor = tensor_create(CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride, + false, true); // is by value + + auto seqlenQTensor = + tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + auto seqlenKTensor = + tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + // gen index row output + auto rowIndexTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 100, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + // gen index column output + auto columnIndexTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 101, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + // less than row output + auto lessThanRowTensor = + tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 102, afterBMM1_dim, afterBMM1_stride, true, + false); // is virtual + // less than column output + auto lessThanColTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 103, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + // padding mask (lessthanRow && lessthanCol) + auto paddingMaskTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 104, 
afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + // row >= col check for causal mask + auto rowGreaterColTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 105, afterBMM1_dim, afterBMM1_stride, true, false); // is virtual - // padding mask (lessthanRow && lessthanCol) - auto paddingMaskTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 104, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // row >= col check for causal mask - auto rowGreaterColTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 105, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // create causal mask (padding && row >= col) - auto causalMaskTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 106, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // output after masking - int64_t maskOutputTensor_id = VIRTUAL_ID + 107; - int64_t maskOutputTensor_virtual = true; - cudnnDataType_t maskOutputTensor_dataType = CUDNN_DATA_FLOAT; - auto maskOutputTensor_reorderType = - cudnn_frontend::TensorReordering_t::NONE; - - if (is_bprop) { - maskOutputTensor_id = dS_ID; - maskOutputTensor_virtual = false; - maskOutputTensor_dataType = tensorType; - maskOutputTensor_reorderType = - cudnn_frontend::TensorReordering_t::F16x16; - } + // create causal mask (padding && row >= col) + auto causalMaskTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 106, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual - auto maskOutputTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setByValue(false) - .setDataType(maskOutputTensor_dataType) - .setVirtual(maskOutputTensor_virtual) - .setId(maskOutputTensor_id) - .setReorderType(maskOutputTensor_reorderType) - .build(); - - // Define the gen index for row descriptor - auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_GEN_INDEX) - .setAxis(2) - .setComputeType(CUDNN_DATA_FLOAT) - .build(); + // output after masking + int64_t maskOutputTensor_id = VIRTUAL_ID + 107; + int64_t maskOutputTensor_virtual = true; + cudnnDataType_t maskOutputTensor_dataType = CUDNN_DATA_FLOAT; + auto maskOutputTensor_reorderType = cudnn_frontend::TensorReordering_t::NONE; + + if (is_bprop) { + maskOutputTensor_id = dS_ID; + maskOutputTensor_virtual = false; + maskOutputTensor_dataType = tensorType; + maskOutputTensor_reorderType = cudnn_frontend::TensorReordering_t::F16x16; + } + + auto maskOutputTensor = + cudnn_frontend::TensorBuilder() + .setDim(4, afterBMM1_dim) + .setStride(4, afterBMM1_stride) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setByValue(false) + .setDataType(maskOutputTensor_dataType) + .setVirtual(maskOutputTensor_virtual) + .setId(maskOutputTensor_id) + .setReorderType(maskOutputTensor_reorderType) + .build(); + + // Define the gen index for row descriptor + auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_GEN_INDEX) + .setAxis(2) + .setComputeType(CUDNN_DATA_FLOAT) + .build(); - // Create a gen index Node. - auto genIndexRow_op = - unary_pw_op_create(prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc); + // Create a gen index Node. 
+ auto genIndexRow_op = unary_pw_op_create(prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc); - // Define the gen index for row descriptor - auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_GEN_INDEX) - .setAxis(3) - .setComputeType(CUDNN_DATA_FLOAT) - .build(); + // Define the gen index for row descriptor + auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_GEN_INDEX) + .setAxis(3) + .setComputeType(CUDNN_DATA_FLOAT) + .build(); - // Create a gen index Node. - auto genIndexColumn_op = - unary_pw_op_create(prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc); + // Create a gen index Node. + auto genIndexColumn_op = + unary_pw_op_create(prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc); - // Define the less than comparison for row descriptor - auto lessThanRowDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT); + // Define the less than comparison for row descriptor + auto lessThanRowDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT); - // Create a less than comparison for row Node. - auto lessThanRow_op = - binary_pw_op_create(rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc); + // Create a less than comparison for row Node. + auto lessThanRow_op = + binary_pw_op_create(rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc); - // Define the less than comparison for column descriptor - auto lessThanColDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT); + // Define the less than comparison for column descriptor + auto lessThanColDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT); - // Create a less than comparison for col Node. - auto lessThanCol_op = - binary_pw_op_create(columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc); + // Create a less than comparison for col Node. + auto lessThanCol_op = + binary_pw_op_create(columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc); - // Define the less than comparison for column descriptor - auto paddingMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND); + // Define the less than comparison for column descriptor + auto paddingMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND); - // Create a and node for combining lessThanRow and lessThanCol - auto paddingMaskAnd_op = binary_pw_op_create(lessThanRowTensor, lessThanColTensor, - paddingMaskTensor, paddingMaskAndDesc); + // Create a and node for combining lessThanRow and lessThanCol + auto paddingMaskAnd_op = binary_pw_op_create(lessThanRowTensor, lessThanColTensor, + paddingMaskTensor, paddingMaskAndDesc); - // Define the greater than equal to comparison descriptor - auto rowGreaterColDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_CMP_GE); + // Define the greater than equal to comparison descriptor + auto rowGreaterColDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_CMP_GE); - // Create a greater than equal to Node. - auto rowGreaterCol_op = binary_pw_op_create(rowIndexTensor, columnIndexTensor, - rowGreaterColTensor, rowGreaterColDesc); + // Create a greater than equal to Node. 
+ auto rowGreaterCol_op = binary_pw_op_create(rowIndexTensor, columnIndexTensor, + rowGreaterColTensor, rowGreaterColDesc); - // Define the and to create causal mask descriptor - auto causalMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND); + // Define the and to create causal mask descriptor + auto causalMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND); - // Create a causal Mask Node. - auto causalMaskAnd_op = binary_pw_op_create(paddingMaskTensor, rowGreaterColTensor, - causalMaskTensor, causalMaskAndDesc); + // Create a causal Mask Node. + auto causalMaskAnd_op = binary_pw_op_create(paddingMaskTensor, rowGreaterColTensor, + causalMaskTensor, causalMaskAndDesc); - /////////////////// Apply the mask ////////////////////////// + /////////////////// Apply the mask ////////////////////////// - auto maskTensor = (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - ? std::move(causalMaskTensor) - : std::move(paddingMaskTensor); + auto maskTensor = (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) + ? std::move(causalMaskTensor) + : std::move(paddingMaskTensor); - // Define the binary select to perform masking descriptor - auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT); + // Define the binary select to perform masking descriptor + auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT); - // Create a binary select Node. - auto mask_op = ternary_pw_op_create(prevBlockOutputTensor, maskValTensor, maskTensor, - maskOutputTensor, maskDesc); + // Create a binary select Node. + auto mask_op = ternary_pw_op_create(prevBlockOutputTensor, maskValTensor, maskTensor, + maskOutputTensor, maskDesc); - ops.push_back(std::move(genIndexRow_op)); - ops.push_back(std::move(genIndexColumn_op)); - ops.push_back(std::move(lessThanRow_op)); - ops.push_back(std::move(lessThanCol_op)); - ops.push_back(std::move(paddingMaskAnd_op)); - if (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) { - ops.push_back(std::move(rowGreaterCol_op)); - ops.push_back(std::move(causalMaskAnd_op)); - } - ops.push_back(std::move(mask_op)); + ops.push_back(std::move(genIndexRow_op)); + ops.push_back(std::move(genIndexColumn_op)); + ops.push_back(std::move(lessThanRow_op)); + ops.push_back(std::move(lessThanCol_op)); + ops.push_back(std::move(paddingMaskAnd_op)); + if (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) { + ops.push_back(std::move(rowGreaterCol_op)); + ops.push_back(std::move(causalMaskAnd_op)); + } + ops.push_back(std::move(mask_op)); - return maskOutputTensor; + return maskOutputTensor; } static cudnn_frontend::Tensor createSoftmaxForward( @@ -326,102 +321,100 @@ static cudnn_frontend::Tensor createSoftmaxForward( bool enable_dropout, bool softmax_output_virtual, cudnnDataType_t tensorType, std::vector &ops, cudnn_frontend::Tensor const &prevBlockOutputTensor) { - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t afterReduction_dim[4] = {b, h, s_q, 1}; - int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1}; - - cudnnDataType_t softmaxOutputType = enable_dropout ? CUDNN_DATA_FLOAT : tensorType; - uint64_t softmaxOutputName = softmax_output_virtual ? 
VIRTUAL_ID + 154 : S_ID; - - // max (x) - auto afterMaxReductionTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 150, afterReduction_dim, afterReduction_stride, - true, false); // is virtual - // x - max(x) - auto afterSubtractionTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 151, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // e^(x - max(x)) - auto afterExponentTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 152, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual; - // sum (e^(x - max(x))) - auto afterAddReductionTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 153, afterReduction_dim, afterReduction_stride, - true, false); // is virtual - // divide (e/ sum(e)) - - auto reorder_type = - cudnn_frontend::TensorReordering_t::F16x16; - - auto afterDivisionTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setId(softmaxOutputName) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(softmaxOutputType) - .setVirtual(softmax_output_virtual) - .setByValue(false) - .setReorderType(reorder_type) - .build(); - - // Define the reduction descriptor - auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_MAX) - .build(); - - // Create a reduction max Node. - auto reductionMax_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(prevBlockOutputTensor) - .setyDesc(afterMaxReductionTensor) - .setreductionDesc(reductionMaxDesc) - .build(); - - // Define the subtract descriptor - auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - - // Create a subtract Node. - auto subtract_op = binary_pw_op_create(prevBlockOutputTensor, afterMaxReductionTensor, - afterSubtractionTensor, subtractDesc); - - // Define the exponent descriptor - auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); - - // Create a exponent Node. - auto exponent_op = - unary_pw_op_create(afterSubtractionTensor, afterExponentTensor, exponentDesc); - - // Define the reduction descriptor - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - // Create a reduction add Node. - auto reductionAdd_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(afterExponentTensor) - .setyDesc(afterAddReductionTensor) - .setreductionDesc(reductionAddDesc) - .build(); - - // Define the division descriptor - auto divisionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_DIV); - - // Create a subtract Node. - auto division_op = binary_pw_op_create(afterExponentTensor, afterAddReductionTensor, - afterDivisionTensor, divisionDesc); - - ops.push_back(std::move(reductionMax_op)); - ops.push_back(std::move(subtract_op)); - ops.push_back(std::move(exponent_op)); - ops.push_back(std::move(reductionAdd_op)); - ops.push_back(std::move(division_op)); - - return afterDivisionTensor; + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t afterReduction_dim[4] = {b, h, s_q, 1}; + int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1}; + + cudnnDataType_t softmaxOutputType = enable_dropout ? CUDNN_DATA_FLOAT : tensorType; + uint64_t softmaxOutputName = softmax_output_virtual ? 
VIRTUAL_ID + 154 : S_ID; + + // max (x) + auto afterMaxReductionTensor = + tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 150, afterReduction_dim, afterReduction_stride, + true, false); // is virtual + // x - max(x) + auto afterSubtractionTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 151, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + // e^(x - max(x)) + auto afterExponentTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 152, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual; + // sum (e^(x - max(x))) + auto afterAddReductionTensor = + tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 153, afterReduction_dim, afterReduction_stride, + true, false); // is virtual + // divide (e/ sum(e)) + + auto reorder_type = cudnn_frontend::TensorReordering_t::F16x16; + + auto afterDivisionTensor = + cudnn_frontend::TensorBuilder() + .setDim(4, afterBMM1_dim) + .setStride(4, afterBMM1_stride) + .setId(softmaxOutputName) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(softmaxOutputType) + .setVirtual(softmax_output_virtual) + .setByValue(false) + .setReorderType(reorder_type) + .build(); + + // Define the reduction descriptor + auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_MAX) + .build(); + + // Create a reduction max Node. + auto reductionMax_op = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(prevBlockOutputTensor) + .setyDesc(afterMaxReductionTensor) + .setreductionDesc(reductionMaxDesc) + .build(); + + // Define the subtract descriptor + auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); + + // Create a subtract Node. + auto subtract_op = binary_pw_op_create(prevBlockOutputTensor, afterMaxReductionTensor, + afterSubtractionTensor, subtractDesc); + + // Define the exponent descriptor + auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); + + // Create a exponent Node. + auto exponent_op = unary_pw_op_create(afterSubtractionTensor, afterExponentTensor, exponentDesc); + + // Define the reduction descriptor + auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) + .build(); + + // Create a reduction add Node. + auto reductionAdd_op = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(afterExponentTensor) + .setyDesc(afterAddReductionTensor) + .setreductionDesc(reductionAddDesc) + .build(); + + // Define the division descriptor + auto divisionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_DIV); + + // Create a subtract Node. 
+ auto division_op = binary_pw_op_create(afterExponentTensor, afterAddReductionTensor, + afterDivisionTensor, divisionDesc); + + ops.push_back(std::move(reductionMax_op)); + ops.push_back(std::move(subtract_op)); + ops.push_back(std::move(exponent_op)); + ops.push_back(std::move(reductionAdd_op)); + ops.push_back(std::move(division_op)); + + return afterDivisionTensor; } static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, @@ -429,124 +422,123 @@ static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, i cudnnDataType_t tensorType, std::vector &ops, cudnn_frontend::Tensor const &prevBlockOutputTensor) { - NVTE_CHECK(ops.size() != 0, "Dropout DAG constructed incorrectly as the first one"); - - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - // mask for the dropout - auto dropoutMaskTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 200, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual + NVTE_CHECK(ops.size() != 0, "Dropout DAG constructed incorrectly as the first one"); + + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + // mask for the dropout + auto dropoutMaskTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 200, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + + auto reorder_type = cudnn_frontend::TensorReordering_t::F16x16; + + // after dropout tensor + auto afterDropoutTensor = + cudnn_frontend::TensorBuilder() + .setDim(4, afterBMM1_dim) + .setStride(4, afterBMM1_stride) + .setId(S_ID) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(tensorType) + .setVirtual(false) + .setByValue(false) + .setReorderType(reorder_type) + .build(); + // scale after dropout + auto scaleDropoutTensor = + tensor_create(tensorType, DROPOUT_CONST_ID, scale_dim, scale_stride, false, + true); // is by value + // after Scale + auto afterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 201, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual - auto reorder_type = - cudnn_frontend::TensorReordering_t::F16x16; - - // after dropout tensor - auto afterDropoutTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setId(S_ID) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(tensorType) - .setVirtual(false) - .setByValue(false) - .setReorderType(reorder_type) - .build(); - // scale after dropout - auto scaleDropoutTensor = - tensor_create(tensorType, DROPOUT_CONST_ID, scale_dim, scale_stride, false, - true); // is by value - // after Scale - auto afterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 201, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // Define the reduction descriptor - auto rngDesc = cudnn_frontend::RngDescBuilder() - .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) - .setBernoulliDistProbability(1.0 - probability) - .build(); - - auto dropoutSeed = - tensor_create(CUDNN_DATA_INT64, DROPOUT_SEED_ID, scale_dim, scale_stride, false, false); - auto dropoutOffset = - tensor_create(CUDNN_DATA_INT64, DROPOUT_OFFSET_ID, scale_dim, scale_stride, false, false); - - // Create a rng Node. 
- auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) - .setyDesc(dropoutMaskTensor) - .setSeedDesc(dropoutSeed) - .setOffsetDesc(dropoutOffset) - .setRngDesc(rngDesc) - .build(); + // Define the reduction descriptor + auto rngDesc = cudnn_frontend::RngDescBuilder() + .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) + .setBernoulliDistProbability(1.0 - probability) + .build(); + + auto dropoutSeed = + tensor_create(CUDNN_DATA_INT64, DROPOUT_SEED_ID, scale_dim, scale_stride, false, false); + auto dropoutOffset = + tensor_create(CUDNN_DATA_INT64, DROPOUT_OFFSET_ID, scale_dim, scale_stride, false, false); + + // Create a rng Node. + auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) + .setyDesc(dropoutMaskTensor) + .setSeedDesc(dropoutSeed) + .setOffsetDesc(dropoutOffset) + .setRngDesc(rngDesc) + .build(); - // Define the multiply mask descriptor - auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + // Define the multiply mask descriptor + auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - // Create a multiply mask Node. - auto maskMul_op = binary_pw_op_create(prevBlockOutputTensor, dropoutMaskTensor, - afterDropoutTensor, maskMulDesc); + // Create a multiply mask Node. + auto maskMul_op = binary_pw_op_create(prevBlockOutputTensor, dropoutMaskTensor, + afterDropoutTensor, maskMulDesc); - // Define the multiply scale descriptor - auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + // Define the multiply scale descriptor + auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - // Create a multiply mask Node. - auto scaleMul_op = - binary_pw_op_create(afterDropoutTensor, scaleDropoutTensor, afterScaleTensor, scaleMulDesc); + // Create a multiply mask Node. 
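// [Editorial note, not part of the patch] The RNG node above and the two pointwise multiplies
// below implement inverted dropout: activations are multiplied by a Bernoulli keep-mask drawn
// with probability (1 - p) and then rescaled by 1 / (1 - p) so the expected value is unchanged,
// matching scale_dropout = 1 / (1 - dropout_probability) used later. A scalar sketch under
// those assumptions; names here are illustrative only.
#include <cstdint>

static void inverted_dropout_reference(const float *x, const float *keep_mask, float *y,
                                       int64_t n, float p) {
  const float scale = 1.0f / (1.0f - p);  // compensates for the dropped fraction p
  for (int64_t i = 0; i < n; ++i) {
    y[i] = x[i] * keep_mask[i] * scale;  // keep_mask entries are 0.0f or 1.0f
  }
}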
+ auto scaleMul_op = + binary_pw_op_create(afterDropoutTensor, scaleDropoutTensor, afterScaleTensor, scaleMulDesc); - ops.push_back(std::move(rng_op)); - ops.push_back(std::move(maskMul_op)); - ops.push_back(std::move(scaleMul_op)); + ops.push_back(std::move(rng_op)); + ops.push_back(std::move(maskMul_op)); + ops.push_back(std::move(scaleMul_op)); - return afterScaleTensor; + return afterScaleTensor; } static void createBMM2(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, cudnnDataType_t tensorType, std::vector &ops, cudnn_frontend::Tensor const &prevBlockOutputTensor) { - NVTE_CHECK(ops.size() != 0, "BMM2 op constructed incorrectly as the first one"); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - - int64_t v_dim[4] = {b, h, s_kv, d}; - int64_t v_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - - int64_t o_dim[4] = {b, h, s_q, d}; - int64_t o_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - - auto seqlenQTensor = - tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto seqlenKTensor = - tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false); - // second GEMM output - auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false); - - // Define the matmul 2 desc - // set padding value optionally to 0 for writing zeros to O tensor (if not set, old behaviour) - auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); + NVTE_CHECK(ops.size() != 0, "BMM2 op constructed incorrectly as the first one"); + + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + + int64_t v_dim[4] = {b, h, s_kv, d}; + int64_t v_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + + int64_t o_dim[4] = {b, h, s_q, d}; + int64_t o_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + + auto seqlenQTensor = + tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + auto seqlenKTensor = + tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false); + // second GEMM output + auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false); + + // Define the matmul 2 desc + // set padding value optionally to 0 for writing zeros to O tensor (if not set, old behaviour) + auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0.0f) + .build(); + + // Create a matmul 2 Node + auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(prevBlockOutputTensor) + .setbMatDesc(vTensor) + .setcMatDesc(oTensor) + .setmOverrideDesc(seqlenQTensor) + .setkOverrideDesc(seqlenKTensor) + .setmatmulDesc(matmul_2_Desc) + .build(); - // Create a matmul 2 Node - auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(prevBlockOutputTensor) - .setbMatDesc(vTensor) - .setcMatDesc(oTensor) - .setmOverrideDesc(seqlenQTensor) - .setkOverrideDesc(seqlenKTensor) - 
.setmatmulDesc(matmul_2_Desc) - .build(); - - ops.push_back(std::move(matmul_op2)); + ops.push_back(std::move(matmul_op2)); } static cudnn_frontend::Tensor createSoftmaxBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, @@ -555,75 +547,75 @@ static cudnn_frontend::Tensor createSoftmaxBackward(int64_t b, int64_t h, int64_ std::vector &ops, cudnn_frontend::Tensor const &yTensor, cudnn_frontend::Tensor const &dyTensor) { - NVTE_CHECK(ops.size() != 0, "Softmax backward constructed incorrectly as the first one"); - - int64_t p_dim[4] = {b, h, s_q, s_kv}; - int64_t p_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - int64_t p_reduction_dim[4] = {b, h, s_q, 1}; - int64_t p_reduction_stride[4]; - - p_reduction_stride[3] = 1; - p_reduction_stride[2] = 1; - p_reduction_stride[1] = s_q; - p_reduction_stride[0] = s_q * h; - - int64_t const_dim[4] = {1, 1, 1, 1}; - int64_t const_stride[4] = {1, 1, 1, 1}; - - // creating all tensors - auto softmaxScaleTensor = - tensor_create(CUDNN_DATA_FLOAT, S_CONST_ID, const_dim, const_stride, false, true); - auto dyMulYTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 250, p_dim, p_stride, true, false); - auto dxAfterReductionTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 251, p_reduction_dim, - p_reduction_stride, true, false); - auto dxAfterSubtractionTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 252, p_dim, p_stride, true, false); - auto dxUnscaleTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 253, p_dim, p_stride, true, false); - auto dxTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 254, p_dim, p_stride, true, false); - - // creating all ops - // mul (y * dy) - auto mul_1_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto mul_1_op = binary_pw_op_create(yTensor, dyTensor, dyMulYTensor, mul_1_desc); - - // reduction add sum (y * dy) - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - auto reductionAdd_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(dyMulYTensor) - .setyDesc(dxAfterReductionTensor) - .setreductionDesc(reductionAddDesc) - .build(); - - // subtraction (dy - sum(y * dy)) - auto sub_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - auto sub_0_op = - binary_pw_op_create(dyTensor, dxAfterReductionTensor, dxAfterSubtractionTensor, sub_0_desc); - - // mul (y * (dy - sum(y * dy))) - auto mul_2_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto mul_2_op = - binary_pw_op_create(yTensor, dxAfterSubtractionTensor, dxUnscaleTensor, mul_2_desc); - - // mul (scale * dx) - auto mul_3_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto mul_3_op = binary_pw_op_create(dxUnscaleTensor, softmaxScaleTensor, dxTensor, mul_3_desc); - - ops.push_back(std::move(mul_1_op)); - ops.push_back(std::move(reductionAdd_op)); - ops.push_back(std::move(sub_0_op)); - ops.push_back(std::move(mul_2_op)); - ops.push_back(std::move(mul_3_op)); - - return dxTensor; + NVTE_CHECK(ops.size() != 0, "Softmax backward constructed incorrectly as the first one"); + + int64_t p_dim[4] = {b, h, s_q, s_kv}; + int64_t p_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + + int64_t p_reduction_dim[4] = {b, h, s_q, 1}; + int64_t p_reduction_stride[4]; + + p_reduction_stride[3] = 1; + p_reduction_stride[2] = 1; + 
p_reduction_stride[1] = s_q; + p_reduction_stride[0] = s_q * h; + + int64_t const_dim[4] = {1, 1, 1, 1}; + int64_t const_stride[4] = {1, 1, 1, 1}; + + // creating all tensors + auto softmaxScaleTensor = + tensor_create(CUDNN_DATA_FLOAT, S_CONST_ID, const_dim, const_stride, false, true); + auto dyMulYTensor = + tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 250, p_dim, p_stride, true, false); + auto dxAfterReductionTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 251, p_reduction_dim, + p_reduction_stride, true, false); + auto dxAfterSubtractionTensor = + tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 252, p_dim, p_stride, true, false); + auto dxUnscaleTensor = + tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 253, p_dim, p_stride, true, false); + auto dxTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 254, p_dim, p_stride, true, false); + + // creating all ops + // mul (y * dy) + auto mul_1_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + auto mul_1_op = binary_pw_op_create(yTensor, dyTensor, dyMulYTensor, mul_1_desc); + + // reduction add sum (y * dy) + auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) + .build(); + + auto reductionAdd_op = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(dyMulYTensor) + .setyDesc(dxAfterReductionTensor) + .setreductionDesc(reductionAddDesc) + .build(); + + // subtraction (dy - sum(y * dy)) + auto sub_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); + auto sub_0_op = + binary_pw_op_create(dyTensor, dxAfterReductionTensor, dxAfterSubtractionTensor, sub_0_desc); + + // mul (y * (dy - sum(y * dy))) + auto mul_2_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + auto mul_2_op = + binary_pw_op_create(yTensor, dxAfterSubtractionTensor, dxUnscaleTensor, mul_2_desc); + + // mul (scale * dx) + auto mul_3_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + auto mul_3_op = binary_pw_op_create(dxUnscaleTensor, softmaxScaleTensor, dxTensor, mul_3_desc); + + ops.push_back(std::move(mul_1_op)); + ops.push_back(std::move(reductionAdd_op)); + ops.push_back(std::move(sub_0_op)); + ops.push_back(std::move(mul_2_op)); + ops.push_back(std::move(mul_3_op)); + + return dxTensor; } void fused_attn_max_512_fwd_impl( @@ -633,200 +625,198 @@ void fused_attn_max_512_fwd_impl( void *devPtrS, void *devPtrO, void *devPtrBias, void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *workspace, size_t *workspace_size, cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) { - try { - FADescriptor descriptor{b, h, - s_q, s_kv, - d, scaling_factor, - is_training, dropout_probability, - layout, bias_type, - mask_type, tensorType, - false}; - - using CacheType = std::map; - static thread_local CacheType fmha_fprop_cache; - - // softmax auxiliary is only used in the training mode - bool enable_dropout = is_training && (dropout_probability != 0.0f); - - // two conditions that make softmax auxiliary in virtual - // 1. inference mode (not is_training) - // 2. 
dropout enabled: the auxiliary becomes the dropout output - bool softmax_output_virtual = !is_training || enable_dropout; - - // Get plan from cache if cache is available, otherwise create one - auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { - // if hit, return - auto it = cache.find(descriptor); - if (it != cache.end()) { - auto plan = it->second; - return plan; - } - - // otherwise, build the op_graph and the plan. Then update cache - std::vector all_ops; - std::vector ops; - - createScale(b, h, s_q, s_kv, d, layout, tensorType, ops); - - // if bias, we need to memset the S buffer to correctly computate dbias - // WAR: causal_mask without bias needs memset the S buffer - // inference mode doesn't need the S auxiliary - auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) || - (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) && is_training; - std::shared_ptr maskInput; - auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops); - - NVTE_CHECK(bias_type != NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS, - "NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS has not been implemented."); - - if (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) { - auto bias_output = createBias(b, h, s_q, s_kv, d, layout, - tensorType, ops, bmm1_output); - maskInput = std::make_shared(std::move(bias_output)); - } - if (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) { - maskInput = std::make_shared(std::move(bmm1_output)); - } - - auto mask_output = createMask(b, h, s_q, s_kv, d, layout, mask_type, tensorType, ops, - *maskInput.get(), false); - - NVTE_CHECK(dropout_probability != 1.0f, "Dropout probability cannot be 1.0."); - - auto softmax_output = - createSoftmaxForward(b, h, s_q, s_kv, d, layout, enable_dropout, - softmax_output_virtual, tensorType, ops, mask_output); - - if (enable_dropout) { - auto dropout_output = createDropout(b, h, s_q, s_kv, d, dropout_probability, - tensorType, ops, softmax_output); - createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, dropout_output); - } else { - createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, softmax_output); - } - - for (unsigned int i = 0; i < ops.size(); i++) { - all_ops.push_back(&ops[i]); - } - - // Create an Operation Graph - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle) - .setOperationGraph(all_ops.size(), all_ops.data()) - .build(); + try { + FADescriptor descriptor{b, h, + s_q, s_kv, + d, scaling_factor, + is_training, dropout_probability, + layout, bias_type, + mask_type, tensorType, + false}; + + using CacheType = std::map; + static thread_local CacheType fmha_fprop_cache; + + // softmax auxiliary is only used in the training mode + bool enable_dropout = is_training && (dropout_probability != 0.0f); + + // two conditions that make softmax auxiliary in virtual + // 1. inference mode (not is_training) + // 2. dropout enabled: the auxiliary becomes the dropout output + bool softmax_output_virtual = !is_training || enable_dropout; + + // Get plan from cache if cache is available, otherwise create one + auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { + // if hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto plan = it->second; + return plan; + } + + // otherwise, build the op_graph and the plan. 
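// [Editorial note, not part of the patch] The get_plan lambda that begins above is a
// get-or-build cache: execution plans are expensive to construct, so they are memoized in a
// thread_local map keyed by FADescriptor (shapes, dtype, dropout, layout, bias/mask type).
// A stripped-down sketch of the same pattern with hypothetical generic types, for orientation:
#include <map>

template <typename Key, typename Plan, typename BuildFn>
static Plan &get_or_build_plan(std::map<Key, Plan> &cache, const Key &key, BuildFn build) {
  auto it = cache.find(key);  // cache hit: reuse the previously built plan
  if (it != cache.end()) return it->second;
  auto inserted = cache.emplace(key, build(key));  // miss: build once, then memoize
  return inserted.first->second;
}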
Then update cache + std::vector all_ops; + std::vector ops; + + createScale(b, h, s_q, s_kv, d, layout, tensorType, ops); + + // if bias, we need to memset the S buffer to correctly computate dbias + // WAR: causal_mask without bias needs memset the S buffer + // inference mode doesn't need the S auxiliary + auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) || + (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) && + is_training; + std::shared_ptr maskInput; + auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops); + + NVTE_CHECK(bias_type != NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS, + "NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS has not been implemented."); + + if (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) { + auto bias_output = createBias(b, h, s_q, s_kv, d, layout, tensorType, ops, bmm1_output); + maskInput = std::make_shared(std::move(bias_output)); + } + if (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) { + maskInput = std::make_shared(std::move(bmm1_output)); + } + + auto mask_output = createMask(b, h, s_q, s_kv, d, layout, mask_type, tensorType, ops, + *maskInput.get(), false); + + NVTE_CHECK(dropout_probability != 1.0f, "Dropout probability cannot be 1.0."); + + auto softmax_output = + createSoftmaxForward(b, h, s_q, s_kv, d, layout, enable_dropout, softmax_output_virtual, + tensorType, ops, mask_output); + + if (enable_dropout) { + auto dropout_output = + createDropout(b, h, s_q, s_kv, d, dropout_probability, tensorType, ops, softmax_output); + createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, dropout_output); + } else { + createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, softmax_output); + } + + for (unsigned int i = 0; i < ops.size(); i++) { + all_ops.push_back(&ops[i]); + } + + // Create an Operation Graph + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle) + .setOperationGraph(all_ops.size(), all_ops.data()) + .build(); + + cudnn_frontend::EngineConfigList filtered_configs; + auto statuses = cudnn_frontend::get_heuristics_list<1>( + {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); + + if (filtered_configs.size() == 0) { + cudnn_frontend::set_error_and_throw_exception( + nullptr, CUDNN_STATUS_NOT_SUPPORTED, + "run_mha_fprop: No config returned by the heuristics"); + } + auto plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle) + .setEngineConfig(filtered_configs[0], opGraph.getTag()) + .build(); + cache.insert({descriptor, plan}); + return plan; + }; - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); - - if (filtered_configs.size() == 0) { - cudnn_frontend::set_error_and_throw_exception( - nullptr, CUDNN_STATUS_NOT_SUPPORTED, - "run_mha_fprop: No config returned by the heuristics"); - } - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle) - .setEngineConfig(filtered_configs[0], opGraph.getTag()) - .build(); - cache.insert({descriptor, plan}); - return plan; - }; - - auto plan = get_plan(fmha_fprop_cache, descriptor); - - auto plan_workspace_size = plan.getWorkspaceSize(); - - // Exit to request upper level API to allocate memory if needed - if (workspace == nullptr) { - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); - *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; - return; - } - - // cuDNN stream check needs to be moved 
here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. - NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - - // Prepare actual seqlen - constexpr size_t nthreads_per_block = 128; - const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); - cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlenQ), - static_cast(devPtrCuSeqlenKV), - static_cast(devActualSeqlenQ), static_cast(devActualSeqlenK)); - NVTE_CHECK_CUDA(cudaGetLastError()); - - // change this if you have access to float_min - float negInfinity = -1.0E+10; - float scale_dropout = 1 / (1 - dropout_probability); - - std::set> data_ptrs; - // add all the data pointers to be used in the variant pack - data_ptrs.insert(std::pair(Q_ID, devPtrQ)); - data_ptrs.insert(std::pair(K_ID, devPtrK)); - data_ptrs.insert(std::pair(V_ID, devPtrV)); - data_ptrs.insert(std::pair(Q_SEQLEN_ID, devActualSeqlenQ)); - data_ptrs.insert(std::pair(K_SEQLEN_ID, devActualSeqlenK)); - data_ptrs.insert(std::pair(MASK_VAL_ID, &negInfinity)); - - __half half_cast_scaling_factor{scaling_factor}; - __nv_bfloat16 bfloat_cast_scaling_factor{scaling_factor}; - - if (tensorType == CUDNN_DATA_FLOAT) { - data_ptrs.insert(std::pair(S_CONST_ID, &scaling_factor)); - } else if (tensorType == CUDNN_DATA_HALF) { - data_ptrs.insert(std::pair(S_CONST_ID, &half_cast_scaling_factor)); - } else if (tensorType == CUDNN_DATA_BFLOAT16) { - data_ptrs.insert(std::pair(S_CONST_ID, &bfloat_cast_scaling_factor)); - } else { - NVTE_ERROR("Unsupported tensor type."); - } - - data_ptrs.insert(std::pair(O_ID, devPtrO)); - - if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) { - data_ptrs.insert(std::pair(B_ID, devPtrBias)); - } - - // if enable_dropout, S is the result after dropout - // if not enable dropout, S is the result after softmax - if (enable_dropout || !softmax_output_virtual) { - data_ptrs.insert(std::pair(S_ID, devPtrS)); - } - - __half half_cast_scale_dropout{scale_dropout}; - __nv_bfloat16 bfloat16_cast_scale_dropout{scale_dropout}; - - if (enable_dropout) { - // TODO(rewang): make a util func - if (tensorType == CUDNN_DATA_FLOAT) { - data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &scale_dropout)); - } else if (tensorType == CUDNN_DATA_HALF) { - data_ptrs.insert( - std::pair(DROPOUT_CONST_ID, &half_cast_scale_dropout)); - } else if (tensorType == CUDNN_DATA_BFLOAT16) { - data_ptrs.insert( - std::pair(DROPOUT_CONST_ID, &bfloat16_cast_scale_dropout)); - } else { - NVTE_ERROR("Unsupported tensor type."); - } - data_ptrs.insert(std::pair(DROPOUT_SEED_ID, devPtrDropoutSeed)); - data_ptrs.insert(std::pair(DROPOUT_OFFSET_ID, devPtrDropoutOffset)); - } - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace) - .setDataPointers(data_ptrs) - .build(); + auto plan = get_plan(fmha_fprop_cache, descriptor); + + auto plan_workspace_size = plan.getWorkspaceSize(); + + // Exit to request upper level API to allocate memory if needed + if (workspace == nullptr) { + size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); + *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; + return; + } + + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. 
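// [Editorial note, not part of the patch] The early return just above implements a two-pass
// workspace protocol: when called with workspace == nullptr the function only reports the
// required byte count (plan workspace plus 2 * b * sizeof(int32_t) for the actual-seqlen
// buffers) and returns without enqueueing work; the caller allocates that much device memory
// and calls again. A generic caller-side sketch under those assumptions, with a hypothetical
// helper and error handling omitted:
#include <cstddef>
#include <cuda_runtime.h>

// Any sizing-style launcher: with ws == nullptr it writes the needed bytes and returns.
using SizedLaunchFn = void (*)(void *ws, size_t *ws_bytes);

static void launch_with_workspace(SizedLaunchFn fn) {
  size_t ws_bytes = 0;
  fn(nullptr, &ws_bytes);     // pass 1: query required workspace size only
  void *ws = nullptr;
  cudaMalloc(&ws, ws_bytes);  // allocate device memory (check the return code in real code)
  fn(ws, &ws_bytes);          // pass 2: execute with the allocated workspace
  cudaFree(ws);
}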
+ NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + + // Prepare actual seqlen + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + cu_seqlens_to_actual_seqlens<<>>( + b, static_cast(devPtrCuSeqlenQ), + static_cast(devPtrCuSeqlenKV), static_cast(devActualSeqlenQ), + static_cast(devActualSeqlenK)); + NVTE_CHECK_CUDA(cudaGetLastError()); + + // change this if you have access to float_min + float negInfinity = -1.0E+10; + float scale_dropout = 1 / (1 - dropout_probability); + + std::set> data_ptrs; + // add all the data pointers to be used in the variant pack + data_ptrs.insert(std::pair(Q_ID, devPtrQ)); + data_ptrs.insert(std::pair(K_ID, devPtrK)); + data_ptrs.insert(std::pair(V_ID, devPtrV)); + data_ptrs.insert(std::pair(Q_SEQLEN_ID, devActualSeqlenQ)); + data_ptrs.insert(std::pair(K_SEQLEN_ID, devActualSeqlenK)); + data_ptrs.insert(std::pair(MASK_VAL_ID, &negInfinity)); + + __half half_cast_scaling_factor{scaling_factor}; + __nv_bfloat16 bfloat_cast_scaling_factor{scaling_factor}; + + if (tensorType == CUDNN_DATA_FLOAT) { + data_ptrs.insert(std::pair(S_CONST_ID, &scaling_factor)); + } else if (tensorType == CUDNN_DATA_HALF) { + data_ptrs.insert(std::pair(S_CONST_ID, &half_cast_scaling_factor)); + } else if (tensorType == CUDNN_DATA_BFLOAT16) { + data_ptrs.insert(std::pair(S_CONST_ID, &bfloat_cast_scaling_factor)); + } else { + NVTE_ERROR("Unsupported tensor type."); + } + + data_ptrs.insert(std::pair(O_ID, devPtrO)); + + if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) { + data_ptrs.insert(std::pair(B_ID, devPtrBias)); + } + + // if enable_dropout, S is the result after dropout + // if not enable dropout, S is the result after softmax + if (enable_dropout || !softmax_output_virtual) { + data_ptrs.insert(std::pair(S_ID, devPtrS)); + } - NVTE_CHECK_CUDNN( - cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc())); - } catch (cudnn_frontend::cudnnException &e) { - NVTE_ERROR(e.what()); + __half half_cast_scale_dropout{scale_dropout}; + __nv_bfloat16 bfloat16_cast_scale_dropout{scale_dropout}; + + if (enable_dropout) { + // TODO(rewang): make a util func + if (tensorType == CUDNN_DATA_FLOAT) { + data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &scale_dropout)); + } else if (tensorType == CUDNN_DATA_HALF) { + data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &half_cast_scale_dropout)); + } else if (tensorType == CUDNN_DATA_BFLOAT16) { + data_ptrs.insert( + std::pair(DROPOUT_CONST_ID, &bfloat16_cast_scale_dropout)); + } else { + NVTE_ERROR("Unsupported tensor type."); + } + data_ptrs.insert(std::pair(DROPOUT_SEED_ID, devPtrDropoutSeed)); + data_ptrs.insert(std::pair(DROPOUT_OFFSET_ID, devPtrDropoutOffset)); } + + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace) + .setDataPointers(data_ptrs) + .build(); + + NVTE_CHECK_CUDNN(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc())); + } catch (cudnn_frontend::cudnnException &e) { + NVTE_ERROR(e.what()); + } } void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, @@ -838,403 +828,387 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, void *workspace, size_t *workspace_size, cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) 
{ - try { - FADescriptor descriptor{ - b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability, - layout, bias_type, mask_type, tensorType, false}; - - using CacheType = std::map; - static thread_local CacheType fmha_bprop_cache; - - auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { - auto it = cache.find(descriptor); - if (it != cache.end()) { - return it->second; - } - - std::vector all_ops; - std::vector ops; - - // Creates the necessary tensor descriptors - int64_t q_dim[4] = {b, h, s_q, d}; - int64_t q_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - - int64_t k_dim[4] = {b, h, s_kv, d}; - int64_t k_stride[4]; - generateMatrixStrides( - b, h, s_q, s_kv, d, k_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); // type is correct as K is not transposed - - int64_t v_dim[4] = {b, h, d, s_kv}; - int64_t v_stride[4]; - generateMatrixStrides( - b, h, s_q, s_kv, d, v_stride, layout, - NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); // type is correct as V is transposed - - int64_t p_dim[4] = {b, h, s_q, s_kv}; - int64_t p_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, - NVTE_QKV_Matrix::NVTE_S_Matrix); - - int64_t p_transpose_dim[4] = {b, h, s_kv, s_q}; - int64_t p_transpose_stride[4]; - p_transpose_stride[0] = p_stride[0]; - p_transpose_stride[1] = p_stride[1]; - p_transpose_stride[2] = p_stride[3]; - p_transpose_stride[3] = p_stride[2]; - - int64_t o_dim[4] = {b, h, s_q, d}; - int64_t o_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - // inputs to fprop - auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false); - auto kTensor = tensor_create(tensorType, K_ID, k_dim, k_stride, false, false); - auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false); - auto seqlenQTensor = tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, - seqlen_stride, false, false); - auto seqlenKTensor = tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, - seqlen_stride, false, false); - - // gradient of the output - auto doTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false); - - auto reorder_type = - cudnn_frontend::TensorReordering_t::F16x16; - - // activation from fprop - auto pTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, p_dim) - .setStride(4, p_stride) - .setId(S_ID) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(tensorType) - .setVirtual(false) - .setByValue(false) - .setReorderType(reorder_type) - .build(); + try { + FADescriptor descriptor{ + b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability, + layout, bias_type, mask_type, tensorType, false}; + + using CacheType = std::map; + static thread_local CacheType fmha_bprop_cache; + + auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { + auto it = cache.find(descriptor); + if (it != cache.end()) { + return it->second; + } + + std::vector all_ops; + std::vector ops; + + // Creates the necessary tensor descriptors + int64_t q_dim[4] = {b, h, s_q, d}; + int64_t q_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); + + int64_t k_dim[4] = {b, h, s_kv, d}; + int64_t k_stride[4]; + generateMatrixStrides( + b, h, s_q, s_kv, d, 
k_stride, layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); // type is correct as K is not transposed + + int64_t v_dim[4] = {b, h, d, s_kv}; + int64_t v_stride[4]; + generateMatrixStrides( + b, h, s_q, s_kv, d, v_stride, layout, + NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); // type is correct as V is transposed + + int64_t p_dim[4] = {b, h, s_q, s_kv}; + int64_t p_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + + int64_t p_transpose_dim[4] = {b, h, s_kv, s_q}; + int64_t p_transpose_stride[4]; + p_transpose_stride[0] = p_stride[0]; + p_transpose_stride[1] = p_stride[1]; + p_transpose_stride[2] = p_stride[3]; + p_transpose_stride[3] = p_stride[2]; + + int64_t o_dim[4] = {b, h, s_q, d}; + int64_t o_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + // inputs to fprop + auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false); + auto kTensor = tensor_create(tensorType, K_ID, k_dim, k_stride, false, false); + auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false); + auto seqlenQTensor = + tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + auto seqlenKTensor = + tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + + // gradient of the output + auto doTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false); + + auto reorder_type = cudnn_frontend::TensorReordering_t::F16x16; + + // activation from fprop + auto pTensor = cudnn_frontend::TensorBuilder() + .setDim(4, p_dim) + .setStride(4, p_stride) + .setId(S_ID) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(tensorType) + .setVirtual(false) + .setByValue(false) + .setReorderType(reorder_type) + .build(); + + // outputs from bprop + auto dqTensor = tensor_create(tensorType, dQ_ID, q_dim, q_stride, false, false); + auto dkTensor = tensor_create(tensorType, dK_ID, k_dim, k_stride, false, false); + auto dvTensor = tensor_create(tensorType, dV_ID, k_dim, k_stride, false, + false); // not transposed therefore k_dim and k_stride + + //////////////////////////////////////////////////////// + // start creating the ops and the intermediate tensors + auto pReshapeTensor = tensor_create(tensorType, VIRTUAL_ID + 300, p_transpose_dim, + p_transpose_stride, true, false); + + // reshape to perform transpose and make pReshape + auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(pTensor) + .setyDesc(pReshapeTensor) + .build(); - // outputs from bprop - auto dqTensor = tensor_create(tensorType, dQ_ID, q_dim, q_stride, false, false); - auto dkTensor = tensor_create(tensorType, dK_ID, k_dim, k_stride, false, false); - auto dvTensor = tensor_create(tensorType, dV_ID, k_dim, k_stride, false, - false); // not transposed therefore k_dim and k_stride - - //////////////////////////////////////////////////////// - // start creating the ops and the intermediate tensors - auto pReshapeTensor = tensor_create(tensorType, VIRTUAL_ID + 300, p_transpose_dim, - p_transpose_stride, true, false); - - // reshape to perform transpose and make pReshape - auto reshape_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(pTensor) - 
.setyDesc(pReshapeTensor) - .build(); + ops.push_back(std::move(reshape_op)); - ops.push_back(std::move(reshape_op)); - - // scale dropout - auto dropoutScaleTensor = tensor_create(CUDNN_DATA_FLOAT, DROPOUT_CONST_ID, scale_dim, - scale_stride, false, true); // is by value - auto pAfterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 301, p_transpose_dim, - p_transpose_stride, true, false); - - auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto scaleMul_op = binary_pw_op_create(pReshapeTensor, dropoutScaleTensor, - pAfterScaleTensor, scaleMulDesc); - ops.push_back(std::move(scaleMul_op)); - - // perform absolute operation to remove the mask bit - auto pTransposeAfterAbsTensor = tensor_create( - tensorType, VIRTUAL_ID + 302, p_transpose_dim, p_transpose_stride, true, false); - - auto absDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ABS); - auto abs_op = unary_pw_op_create(pAfterScaleTensor, pTransposeAfterAbsTensor, absDesc); - ops.push_back(std::move(abs_op)); - - // matmul to calculate dvTensor - // set padding value optionally to 0 for writing zeros to dV tensor (if not set, old - // behaviour) - auto matmul_0_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - - auto matmul_op0 = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(pTransposeAfterAbsTensor) - .setbMatDesc(doTensor) - .setcMatDesc(dvTensor) - .setmOverrideDesc(seqlenKTensor) - .setkOverrideDesc(seqlenQTensor) - .setmatmulDesc(matmul_0_Desc) - .build(); + // scale dropout + auto dropoutScaleTensor = tensor_create(CUDNN_DATA_FLOAT, DROPOUT_CONST_ID, scale_dim, + scale_stride, false, true); // is by value + auto pAfterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 301, p_transpose_dim, + p_transpose_stride, true, false); - ops.push_back(std::move(matmul_op0)); + auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + auto scaleMul_op = + binary_pw_op_create(pReshapeTensor, dropoutScaleTensor, pAfterScaleTensor, scaleMulDesc); + ops.push_back(std::move(scaleMul_op)); - // matmul to calculate dpTensor - auto dpTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 303, p_dim, p_stride, true, false); + // perform absolute operation to remove the mask bit + auto pTransposeAfterAbsTensor = tensor_create(tensorType, VIRTUAL_ID + 302, p_transpose_dim, + p_transpose_stride, true, false); - auto matmul_1_Desc = - cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build(); + auto absDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ABS); + auto abs_op = unary_pw_op_create(pAfterScaleTensor, pTransposeAfterAbsTensor, absDesc); + ops.push_back(std::move(abs_op)); - auto matmul_op1 = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(doTensor) - .setbMatDesc(vTensor) - .setcMatDesc(dpTensor) - .setmOverrideDesc(seqlenQTensor) - .setnOverrideDesc(seqlenKTensor) - .setmatmulDesc(matmul_1_Desc) - .build(); + // matmul to calculate dvTensor + // set padding value optionally to 0 for writing zeros to dV tensor (if not set, old + // behaviour) + auto matmul_0_Desc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0.0f) + .build(); - ops.push_back(std::move(matmul_op1)); - - // mask the values which were dropped in dropout - auto pAbsTensor = - tensor_create(tensorType, VIRTUAL_ID + 304, p_dim, p_stride, true, false); - - auto p_absDesc = 
pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ABS); - auto p_abs_op = unary_pw_op_create(pTensor, pAbsTensor, p_absDesc); - ops.push_back(std::move(p_abs_op)); - - // create the dropout mask - auto zeroTensor = tensor_create(CUDNN_DATA_FLOAT, MASK_VAL_ID, scale_dim, scale_stride, - false, true); // is by value - auto dropoutMaskTensor = - tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 305, p_dim, p_stride, true, false); - - auto greater_than_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_GT); - auto greater_than_0_op = - binary_pw_op_create(pTensor, zeroTensor, dropoutMaskTensor, greater_than_0_desc); - ops.push_back(std::move(greater_than_0_op)); - - // scale for the dropout - auto dpAfterScaleTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 306, p_dim, p_stride, true, false); - - auto mul_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto mul_0_op = - binary_pw_op_create(dpTensor, dropoutScaleTensor, dpAfterScaleTensor, mul_0_desc); - ops.push_back(std::move(mul_0_op)); - - // drop the values based on the dropout mask - auto dpAfterDropoutTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 307, p_dim, p_stride, true, false); - - auto selection_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT); - auto selection_0_op = - ternary_pw_op_create(dpAfterScaleTensor, zeroTensor, dropoutMaskTensor, - dpAfterDropoutTensor, selection_0_desc); - ops.push_back(std::move(selection_0_op)); - - // softmax backward - auto dsTensor = createSoftmaxBackward(b, h, s_q, s_kv, d, layout, tensorType, ops, - pAbsTensor, dpAfterDropoutTensor); - - // mask - auto dsAfterMaskTensor = - createMask(b, h, s_q, s_kv, d, layout, mask_type, tensorType, ops, dsTensor, true); - - // dbias tensor - int64_t dbias_dim[4] = {1, h, s_q, s_kv}; - int64_t dbias_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - auto dBiasTensor = - tensor_create(tensorType, dBias_ID, dbias_dim, dbias_stride, false, false); - - if (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) { - auto softmaxScaleTensor = tensor_create(CUDNN_DATA_FLOAT, S_CONST_ID, scale_dim, - scale_stride, false, true); - auto softmaxScaleReciprocalTensor = tensor_create( - CUDNN_DATA_FLOAT, VIRTUAL_ID + 401, scale_dim, scale_stride, true, false); - auto dbiasBeforeScaleTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 402, - dbias_dim, dbias_stride, true, false); - - // Define the reduction descriptor - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - // Create a reduction add node to compute the dbias - auto reductionAdd_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(dsAfterMaskTensor) - .setyDesc(dbiasBeforeScaleTensor) - .setreductionDesc(reductionAddDesc) - .build(); - ops.push_back(std::move(reductionAdd_op)); - - // take the reciprocal of the scale - auto reciprocal_scale_desc = - pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_RECIPROCAL); - auto reciprocal_scale_op = unary_pw_op_create( - softmaxScaleTensor, softmaxScaleReciprocalTensor, reciprocal_scale_desc); - ops.push_back(std::move(reciprocal_scale_op)); - - // apply the scale - auto dBias_scale_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto dBias_scale_op = - binary_pw_op_create(dbiasBeforeScaleTensor, softmaxScaleReciprocalTensor, - dBiasTensor, dBias_scale_desc); - ops.push_back(std::move(dBias_scale_op)); - } - - // matmul to calculate dqTensor - // 
set padding value optionally to 0 for writing zeros to dqTensor (if not set, old - // behaviour) - auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - - auto matmul_op2 = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dsAfterMaskTensor) - .setbMatDesc(kTensor) - .setcMatDesc(dqTensor) - .setmOverrideDesc(seqlenQTensor) - .setkOverrideDesc(seqlenKTensor) - .setmatmulDesc(matmul_2_Desc) - .build(); + auto matmul_op0 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(pTransposeAfterAbsTensor) + .setbMatDesc(doTensor) + .setcMatDesc(dvTensor) + .setmOverrideDesc(seqlenKTensor) + .setkOverrideDesc(seqlenQTensor) + .setmatmulDesc(matmul_0_Desc) + .build(); - ops.push_back(std::move(matmul_op2)); + ops.push_back(std::move(matmul_op0)); - // reshape for transpose of ds - auto dsAfterMaskReshapeTensor = tensor_create( - tensorType, VIRTUAL_ID + 308, p_transpose_dim, p_transpose_stride, true, false); + // matmul to calculate dpTensor + auto dpTensor = + tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 303, p_dim, p_stride, true, false); - auto reshape_2_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(dsAfterMaskTensor) - .setyDesc(dsAfterMaskReshapeTensor) - .build(); + auto matmul_1_Desc = + cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build(); - ops.push_back(std::move(reshape_2_op)); - - // matmul to calculate dkTensor - // set padding value optionally to 0 for writing zeros to dktensor (if not set, old - // behaviour) - auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - - auto matmul_op3 = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dsAfterMaskReshapeTensor) - .setbMatDesc(qTensor) - .setcMatDesc(dkTensor) - .setmOverrideDesc(seqlenKTensor) - .setkOverrideDesc(seqlenQTensor) - .setmatmulDesc(matmul_3_Desc) - .build(); + auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(doTensor) + .setbMatDesc(vTensor) + .setcMatDesc(dpTensor) + .setmOverrideDesc(seqlenQTensor) + .setnOverrideDesc(seqlenKTensor) + .setmatmulDesc(matmul_1_Desc) + .build(); - ops.push_back(std::move(matmul_op3)); + ops.push_back(std::move(matmul_op1)); + + // mask the values which were dropped in dropout + auto pAbsTensor = tensor_create(tensorType, VIRTUAL_ID + 304, p_dim, p_stride, true, false); + + auto p_absDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ABS); + auto p_abs_op = unary_pw_op_create(pTensor, pAbsTensor, p_absDesc); + ops.push_back(std::move(p_abs_op)); + + // create the dropout mask + auto zeroTensor = tensor_create(CUDNN_DATA_FLOAT, MASK_VAL_ID, scale_dim, scale_stride, false, + true); // is by value + auto dropoutMaskTensor = + tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 305, p_dim, p_stride, true, false); + + auto greater_than_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_GT); + auto greater_than_0_op = + binary_pw_op_create(pTensor, zeroTensor, dropoutMaskTensor, greater_than_0_desc); + ops.push_back(std::move(greater_than_0_op)); + + // scale for the dropout + auto dpAfterScaleTensor = + tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 306, p_dim, p_stride, true, false); + + auto mul_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + auto mul_0_op = + 
binary_pw_op_create(dpTensor, dropoutScaleTensor, dpAfterScaleTensor, mul_0_desc); + ops.push_back(std::move(mul_0_op)); + + // drop the values based on the dropout mask + auto dpAfterDropoutTensor = + tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 307, p_dim, p_stride, true, false); + + auto selection_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT); + auto selection_0_op = ternary_pw_op_create(dpAfterScaleTensor, zeroTensor, dropoutMaskTensor, + dpAfterDropoutTensor, selection_0_desc); + ops.push_back(std::move(selection_0_op)); + + // softmax backward + auto dsTensor = createSoftmaxBackward(b, h, s_q, s_kv, d, layout, tensorType, ops, pAbsTensor, + dpAfterDropoutTensor); + + // mask + auto dsAfterMaskTensor = + createMask(b, h, s_q, s_kv, d, layout, mask_type, tensorType, ops, dsTensor, true); + + // dbias tensor + int64_t dbias_dim[4] = {1, h, s_q, s_kv}; + int64_t dbias_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + auto dBiasTensor = tensor_create(tensorType, dBias_ID, dbias_dim, dbias_stride, false, false); + + if (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) { + auto softmaxScaleTensor = + tensor_create(CUDNN_DATA_FLOAT, S_CONST_ID, scale_dim, scale_stride, false, true); + auto softmaxScaleReciprocalTensor = + tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 401, scale_dim, scale_stride, true, false); + auto dbiasBeforeScaleTensor = + tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 402, dbias_dim, dbias_stride, true, false); + + // Define the reduction descriptor + auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) + .build(); + + // Create a reduction add node to compute the dbias + auto reductionAdd_op = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(dsAfterMaskTensor) + .setyDesc(dbiasBeforeScaleTensor) + .setreductionDesc(reductionAddDesc) + .build(); + ops.push_back(std::move(reductionAdd_op)); + + // take the reciprocal of the scale + auto reciprocal_scale_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_RECIPROCAL); + auto reciprocal_scale_op = unary_pw_op_create( + softmaxScaleTensor, softmaxScaleReciprocalTensor, reciprocal_scale_desc); + ops.push_back(std::move(reciprocal_scale_op)); + + // apply the scale + auto dBias_scale_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + auto dBias_scale_op = binary_pw_op_create( + dbiasBeforeScaleTensor, softmaxScaleReciprocalTensor, dBiasTensor, dBias_scale_desc); + ops.push_back(std::move(dBias_scale_op)); + } + + // matmul to calculate dqTensor + // set padding value optionally to 0 for writing zeros to dqTensor (if not set, old + // behaviour) + auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0.0f) + .build(); - ///////////////////////////////////////////////////////////////// + auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dsAfterMaskTensor) + .setbMatDesc(kTensor) + .setcMatDesc(dqTensor) + .setmOverrideDesc(seqlenQTensor) + .setkOverrideDesc(seqlenKTensor) + .setmatmulDesc(matmul_2_Desc) + .build(); - for (unsigned int i = 0; i < ops.size(); i++) { - all_ops.push_back(&ops[i]); - } + ops.push_back(std::move(matmul_op2)); - // Create an Operation Graph - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle) - .setOperationGraph(all_ops.size(), all_ops.data()) - .build(); + // reshape for transpose of 
ds + auto dsAfterMaskReshapeTensor = tensor_create(tensorType, VIRTUAL_ID + 308, p_transpose_dim, + p_transpose_stride, true, false); - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); + auto reshape_2_op = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(dsAfterMaskTensor) + .setyDesc(dsAfterMaskReshapeTensor) + .build(); - if (filtered_configs.size() == 0) { - cudnn_frontend::set_error_and_throw_exception( - nullptr, CUDNN_STATUS_NOT_SUPPORTED, - "run_mha_bprop: No config returned by the heuristics"); - } + ops.push_back(std::move(reshape_2_op)); - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle) - .setEngineConfig(filtered_configs[0], opGraph.getTag()) - .build(); - cache.insert({descriptor, plan}); - return plan; - }; - - auto plan = get_plan(fmha_bprop_cache, descriptor); - - auto plan_workspace_size = plan.getWorkspaceSize(); - - // Exit to request upper level API to allocate memory if needed - if (workspace == nullptr) { - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); - *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; - return; - } - - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. - NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - - constexpr size_t nthreads_per_block = 128; - const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); - cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlenQ), - static_cast(devPtrCuSeqlenKV), - static_cast(devActualSeqlenQ), static_cast(devActualSeqlenK)); - NVTE_CHECK_CUDA(cudaGetLastError()); - - std::set> data_ptrs; - // add all the data pointers to be used in the variant pack - data_ptrs.insert(std::pair(dQ_ID, devPtrdQ)); - data_ptrs.insert(std::pair(dK_ID, devPtrdK)); - data_ptrs.insert(std::pair(dV_ID, devPtrdV)); - - data_ptrs.insert(std::pair(Q_ID, devPtrQ)); - data_ptrs.insert(std::pair(K_ID, devPtrK)); - data_ptrs.insert(std::pair(V_ID, devPtrV)); - data_ptrs.insert(std::pair(S_ID, devPtrS)); - data_ptrs.insert(std::pair(dO_ID, devPtrdO)); - data_ptrs.insert(std::pair(dS_ID, devPtrdS)); - data_ptrs.insert(std::pair(Q_SEQLEN_ID, devActualSeqlenQ)); - data_ptrs.insert(std::pair(K_SEQLEN_ID, devActualSeqlenK)); - - if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) { - data_ptrs.insert(std::pair(dBias_ID, devPtrdBias)); - } - - float zeroVal = 0.0f; - float dropoutScale = 1.0f / (1.0f - dropout_probability); - - data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &dropoutScale)); - data_ptrs.insert(std::pair(S_CONST_ID, &scaling_factor)); - data_ptrs.insert(std::pair(MASK_VAL_ID, &zeroVal)); - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace) - .setDataPointers(data_ptrs) + // matmul to calculate dkTensor + // set padding value optionally to 0 for writing zeros to dktensor (if not set, old + // behaviour) + auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0.0f) .build(); - NVTE_CHECK_CUDNN( - cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc())); - } catch (cudnn_frontend::cudnnException &e) { - NVTE_ERROR(e.what()); + auto 
matmul_op3 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dsAfterMaskReshapeTensor) + .setbMatDesc(qTensor) + .setcMatDesc(dkTensor) + .setmOverrideDesc(seqlenKTensor) + .setkOverrideDesc(seqlenQTensor) + .setmatmulDesc(matmul_3_Desc) + .build(); + + ops.push_back(std::move(matmul_op3)); + + ///////////////////////////////////////////////////////////////// + + for (unsigned int i = 0; i < ops.size(); i++) { + all_ops.push_back(&ops[i]); + } + + // Create an Operation Graph + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle) + .setOperationGraph(all_ops.size(), all_ops.data()) + .build(); + + cudnn_frontend::EngineConfigList filtered_configs; + auto statuses = cudnn_frontend::get_heuristics_list<1>( + {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); + + if (filtered_configs.size() == 0) { + cudnn_frontend::set_error_and_throw_exception( + nullptr, CUDNN_STATUS_NOT_SUPPORTED, + "run_mha_bprop: No config returned by the heuristics"); + } + + auto plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle) + .setEngineConfig(filtered_configs[0], opGraph.getTag()) + .build(); + cache.insert({descriptor, plan}); + return plan; + }; + + auto plan = get_plan(fmha_bprop_cache, descriptor); + + auto plan_workspace_size = plan.getWorkspaceSize(); + + // Exit to request upper level API to allocate memory if needed + if (workspace == nullptr) { + size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); + *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; + return; } + + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + cu_seqlens_to_actual_seqlens<<>>( + b, static_cast(devPtrCuSeqlenQ), + static_cast(devPtrCuSeqlenKV), static_cast(devActualSeqlenQ), + static_cast(devActualSeqlenK)); + NVTE_CHECK_CUDA(cudaGetLastError()); + + std::set> data_ptrs; + // add all the data pointers to be used in the variant pack + data_ptrs.insert(std::pair(dQ_ID, devPtrdQ)); + data_ptrs.insert(std::pair(dK_ID, devPtrdK)); + data_ptrs.insert(std::pair(dV_ID, devPtrdV)); + + data_ptrs.insert(std::pair(Q_ID, devPtrQ)); + data_ptrs.insert(std::pair(K_ID, devPtrK)); + data_ptrs.insert(std::pair(V_ID, devPtrV)); + data_ptrs.insert(std::pair(S_ID, devPtrS)); + data_ptrs.insert(std::pair(dO_ID, devPtrdO)); + data_ptrs.insert(std::pair(dS_ID, devPtrdS)); + data_ptrs.insert(std::pair(Q_SEQLEN_ID, devActualSeqlenQ)); + data_ptrs.insert(std::pair(K_SEQLEN_ID, devActualSeqlenK)); + + if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) { + data_ptrs.insert(std::pair(dBias_ID, devPtrdBias)); + } + + float zeroVal = 0.0f; + float dropoutScale = 1.0f / (1.0f - dropout_probability); + + data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &dropoutScale)); + data_ptrs.insert(std::pair(S_CONST_ID, &scaling_factor)); + data_ptrs.insert(std::pair(MASK_VAL_ID, &zeroVal)); + + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace) + .setDataPointers(data_ptrs) + .build(); + + NVTE_CHECK_CUDNN(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc())); + } catch 
(cudnn_frontend::cudnnException &e) { + NVTE_ERROR(e.what()); + } } } // namespace fused_attn @@ -1246,65 +1220,65 @@ void fused_attn_max_512_fwd_qkvpacked( NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - // QKV shape is [b, s, 3, h, d] - void *devPtrQKV = input_QKV->data.dptr; - const auto stride = 2 * num_head * head_dim; - - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrBias = static_cast(input_Bias->data.dptr); - - void *devPtrO = output_O->data.dptr; - - void *devPtrS = nullptr; - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen}; - output_S->data.dtype = input_QKV->data.dtype; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrCuSeqlen = cu_seqlens->data.dptr; - - const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - const DType QKV_type = input_QKV->data.dtype; - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, - workspace->data.dptr, &workspace_size, get_cudnn_dtype(QKV_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); + using namespace transformer_engine; + + // QKV shape is [b, s, 3, h, d] + void *devPtrQKV = input_QKV->data.dptr; + const auto stride = 2 * num_head * head_dim; + + void *devPtrQ = static_cast(devPtrQKV); + void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); + void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); + + void *devPtrBias = static_cast(input_Bias->data.dptr); + + void *devPtrO = output_O->data.dptr; + + void *devPtrS = nullptr; + + if (Aux_CTX_Tensors->size == 0) { + Aux_CTX_Tensors->size = 1; + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + output_S->data.dptr = nullptr; + output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen}; + output_S->data.dtype = input_QKV->data.dtype; + } else if (Aux_CTX_Tensors->size == 1) { + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + devPtrS = output_S->data.dptr; + } else { + NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + } + + void *devPtrCuSeqlen = cu_seqlens->data.dptr; + + const DType rng_state_type = rng_state->data.dtype; + NVTE_CHECK(rng_state_type == 
DType::kInt64); + void *devPtrDropoutSeed = rng_state->data.dptr; + void *devPtrDropoutOffset = + static_cast(static_cast(rng_state->data.dptr) + 1); + + const DType QKV_type = input_QKV->data.dtype; + size_t workspace_size = 0; + + fused_attn_max_512_fwd_impl( + batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, + devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, + &workspace_size, get_cudnn_dtype(QKV_type), stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } } void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, @@ -1316,202 +1290,201 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || - bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS, - "NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512."); - - // Q shape is [b, s, h, d] - void *devPtrQ = input_Q->data.dptr; - - // KV shape is [b, s, 2, h, d] - const auto stride = 2 * num_head * head_dim; - void *devPtrK = input_KV->data.dptr; - void *devPtrV = static_cast(static_cast(devPtrK) + stride); - - void *devPtrBias = input_Bias->data.dptr; - - void *devPtrO = output_O->data.dptr; + using namespace transformer_engine; + + NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || + bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS, + "NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512."); + + // Q shape is [b, s, h, d] + void *devPtrQ = input_Q->data.dptr; + + // KV shape is [b, s, 2, h, d] + const auto stride = 2 * num_head * head_dim; + void *devPtrK = input_KV->data.dptr; + void *devPtrV = static_cast(static_cast(devPtrK) + stride); + + void *devPtrBias = input_Bias->data.dptr; + + void *devPtrO = output_O->data.dptr; + + void *devPtrS = nullptr; + + const DType q_type = input_Q->data.dtype; + const DType kv_type = input_KV->data.dtype; + NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); + + if (Aux_CTX_Tensors->size == 0) { + Aux_CTX_Tensors->size = 1; + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + output_S->data.dptr = nullptr; + output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; + output_S->data.dtype = q_type; + } else if (Aux_CTX_Tensors->size == 1) { + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + devPtrS = output_S->data.dptr; + } else { + NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + } + + void *devQCuSeqlen = q_cu_seqlens->data.dptr; + void *devKVCuSeqlen = kv_cu_seqlens->data.dptr; + + const DType rng_state_type = rng_state->data.dtype; + NVTE_CHECK(rng_state_type == DType::kInt64); + void *devPtrDropoutSeed = rng_state->data.dptr; + void *devPtrDropoutOffset = + static_cast(static_cast(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + fused_attn_max_512_fwd_impl( + batch, num_head, 
q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, + devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, + &workspace_size, get_cudnn_dtype(q_type), stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } +} +void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, + size_t kv_max_seqlen, size_t head_dim, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, + const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; + + void *devPtrQ = input_Q->data.dptr; + void *devPtrK = input_K->data.dptr; + void *devPtrV = input_V->data.dptr; + + void *devPtrBias = input_Bias->data.dptr; + + void *devPtrO = output_O->data.dptr; + + void *devPtrS = nullptr; + + const DType q_type = input_Q->data.dtype; + const DType kv_type = input_K->data.dtype; + NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); + + if (Aux_CTX_Tensors->size == 0) { + Aux_CTX_Tensors->size = 1; + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + output_S->data.dptr = nullptr; + output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; + output_S->data.dtype = q_type; + } else if (Aux_CTX_Tensors->size == 1) { + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + devPtrS = output_S->data.dptr; + } else { + NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + } + + void *devQCuSeqlen = q_cu_seqlens->data.dptr; + void *devKVCuSeqlen = kv_cu_seqlens->data.dptr; + + const DType rng_state_type = rng_state->data.dtype; + NVTE_CHECK(rng_state_type == DType::kInt64); + void *devPtrDropoutSeed = rng_state->data.dptr; + void *devPtrDropoutOffset = + static_cast(static_cast(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + fused_attn_max_512_fwd_impl( + batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, + devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, + &workspace_size, get_cudnn_dtype(q_type), stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } +} - void *devPtrS = nullptr; +void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, + size_t head_dim, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, const Tensor *input_QKV, + const Tensor 
*input_dO, Tensor *output_S, Tensor *output_dQKV, + Tensor *output_dBias, const Tensor *cu_seqlens, + Tensor *workspace, cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; - const DType q_type = input_Q->data.dtype; - const DType kv_type = input_KV->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); + // QKV shape is [b, s, 3, h, d] + void *devPtrQKV = input_QKV->data.dptr; - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; - output_S->data.dtype = q_type; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } + auto stride = 2 * num_head * head_dim; + void *devPtrQ = devPtrQKV; + void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); + void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - void *devQCuSeqlen = q_cu_seqlens->data.dptr; - void *devKVCuSeqlen = kv_cu_seqlens->data.dptr; - - const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} -void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; + void *devPtrdO = input_dO->data.dptr; - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_K->data.dptr; - void *devPtrV = input_V->data.dptr; + // dQKV shape is [b, s, 3, h, d] + void *devPtrdQKV = output_dQKV->data.dptr; + void *devPtrdQ = devPtrdQKV; + void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); + void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - void *devPtrBias = input_Bias->data.dptr; + void *devPtrdBias = output_dBias->data.dptr; - void *devPtrO = output_O->data.dptr; + void *devPtrS = output_S->data.dptr; - void *devPtrS = nullptr; + // devPtrdS reuses the memory of devPtrS + void *devPtrdS = devPtrS; - const DType q_type = input_Q->data.dtype; - const 
DType kv_type = input_K->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); + void *devPtrCuSeqlens = cu_seqlens->data.dptr; - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; - output_S->data.dtype = q_type; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } + const auto qkv_type = input_QKV->data.dtype; + size_t workspace_size = 0; - void *devQCuSeqlen = q_cu_seqlens->data.dptr; - void *devKVCuSeqlen = kv_cu_seqlens->data.dptr; - - const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} + fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale, + p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK, + devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS, + devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr, + &workspace_size, get_cudnn_dtype(qkv_type), stream, handle); -void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_dO, Tensor *output_S, - Tensor *output_dQKV, Tensor *output_dBias, - const Tensor *cu_seqlens, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - // QKV shape is [b, s, 3, h, d] - void *devPtrQKV = input_QKV->data.dptr; - - auto stride = 2 * num_head * head_dim; - void *devPtrQ = devPtrQKV; - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrdO = input_dO->data.dptr; - - // dQKV shape is [b, s, 3, h, d] - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = devPtrdQKV; - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - - void *devPtrdBias = output_dBias->data.dptr; - - void *devPtrS = output_S->data.dptr; - - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; - - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - - const auto qkv_type = input_QKV->data.dtype; - size_t workspace_size = 0; - - 
fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale, - p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK, - devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS, - devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(qkv_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } } void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, @@ -1519,119 +1492,117 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_dO, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, + const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ, + Tensor *output_dKV, Tensor *output_dBias, const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - // Q shape is [b, s, h, d] - // KV shape is [b, s, 2, h, d] - auto stride = 2 * num_head * head_dim; - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_KV->data.dptr; - void *devPtrV = static_cast(static_cast(devPtrK) + stride); - - void *devPtrdO = input_dO->data.dptr; - - // dQ shape is [b, s, h, d] - // dKV shape is [b, s, 2, h, d] - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdK = output_dKV->data.dptr; - void *devPtrdV = static_cast(static_cast(devPtrdK) + stride); - - void *devPtrdBias = output_dBias->data.dptr; - - void *devPtrS = output_S->data.dptr; - - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; - - void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr; - void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr; - - const auto q_type = input_Q->data.dtype; - const auto kv_type = input_KV->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - size_t workspace_size = 0; - - fused_attn_max_512_bwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, - mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} -void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, 
size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, - const Tensor *input_dO, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, - Tensor *output_dBias, - const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; + using namespace transformer_engine; - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_K->data.dptr; - void *devPtrV = input_V->data.dptr; + // Q shape is [b, s, h, d] + // KV shape is [b, s, 2, h, d] + auto stride = 2 * num_head * head_dim; + void *devPtrQ = input_Q->data.dptr; + void *devPtrK = input_KV->data.dptr; + void *devPtrV = static_cast(static_cast(devPtrK) + stride); - void *devPtrdO = input_dO->data.dptr; + void *devPtrdO = input_dO->data.dptr; - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdK = output_dK->data.dptr; - void *devPtrdV = output_dV->data.dptr; + // dQ shape is [b, s, h, d] + // dKV shape is [b, s, 2, h, d] + void *devPtrdQ = output_dQ->data.dptr; + void *devPtrdK = output_dKV->data.dptr; + void *devPtrdV = static_cast(static_cast(devPtrdK) + stride); - void *devPtrdBias = output_dBias->data.dptr; + void *devPtrdBias = output_dBias->data.dptr; - void *devPtrS = output_S->data.dptr; + void *devPtrS = output_S->data.dptr; - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; + // devPtrdS reuses the memory of devPtrS + void *devPtrdS = devPtrS; - void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr; - void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr; + void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr; + void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr; - const auto q_type = input_Q->data.dtype; - const auto kv_type = input_K->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - size_t workspace_size = 0; + const auto q_type = input_Q->data.dtype; + const auto kv_type = input_KV->data.dtype; + NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); + size_t workspace_size = 0; - fused_attn_max_512_bwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, - mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); + fused_attn_max_512_bwd_impl( + batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, + mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, + devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr, + &workspace_size, get_cudnn_dtype(q_type), stream, handle); - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + 
workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } +} +void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, + size_t kv_max_seqlen, size_t head_dim, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_dO, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, + Tensor *output_dBias, const Tensor *q_cu_seqlens, + const Tensor *kv_cu_seqlens, Tensor *workspace, cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + + void *devPtrQ = input_Q->data.dptr; + void *devPtrK = input_K->data.dptr; + void *devPtrV = input_V->data.dptr; + + void *devPtrdO = input_dO->data.dptr; + + void *devPtrdQ = output_dQ->data.dptr; + void *devPtrdK = output_dK->data.dptr; + void *devPtrdV = output_dV->data.dptr; + + void *devPtrdBias = output_dBias->data.dptr; + + void *devPtrS = output_S->data.dptr; + + // devPtrdS reuses the memory of devPtrS + void *devPtrdS = devPtrS; + + void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr; + void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr; + + const auto q_type = input_Q->data.dtype; + const auto kv_type = input_K->data.dtype; + NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); + size_t workspace_size = 0; + + fused_attn_max_512_bwd_impl( + batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, + mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, + devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr, + &workspace_size, get_cudnn_dtype(q_type), stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } } } // namespace transformer_engine #endif // CUDNN_VERSION >= 8901 diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h index ababf25b9cd8270a026e40f6b153eea319b218c6..a5b25f3279bf812bf3f41175617731311f2b4c23 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h @@ -11,11 +11,10 @@ #ifndef TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_ #define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_ -#include "transformer_engine/fused_attn.h" - #include #include "common/common.h" +#include "transformer_engine/fused_attn.h" namespace transformer_engine { #if (CUDNN_VERSION >= 8901) @@ -39,46 +38,42 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack 
*Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t kv_max_seqlen, size_t head_dim, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, + const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_dO, Tensor *output_S, - Tensor *output_dQKV, Tensor *output_dBias, - const Tensor *cu_seqlens, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV, + Tensor *output_dBias, const Tensor *cu_seqlens, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_dO, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, + const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ, + Tensor *output_dKV, Tensor *output_dBias, const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, - const Tensor *input_dO, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, - Tensor *output_dBias, - const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t kv_max_seqlen, size_t head_dim, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_dO, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, + Tensor *output_dBias, const Tensor *q_cu_seqlens, + const Tensor *kv_cu_seqlens, Tensor *workspace, cudaStream_t stream, + cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8901 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 2851743dd3a1c21c544a8139fa3a5e24e1efe898..0ee3158106b934ba34f640fc0ee361c68f45b3fa 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -5,9 +5,9 @@ ************************************************************************/ #include "../common.h" -#include "utils.h" #include "../util/system.h" #include "fused_attn_fp8.h" +#include "utils.h" namespace 
transformer_engine { namespace fused_attn { @@ -15,80 +15,75 @@ namespace fused_attn { using namespace transformer_engine; #if (CUDNN_VERSION >= 8900) -std::unordered_map tensor_name_to_uid = { - {"Q", 1}, - {"K", 2}, - {"V", 3}, - {"O", 4}, - {"S", 5}, - {"B", 6}, - {"DROPOUT_SCALE", 7}, - {"S_CONST", 8}, - {"MNK_OVERRIDE", 9}, - {"dQ", 11}, - {"dK", 12}, - {"dV", 13}, - {"dO", 14}, - {"MASK_VAL", 15}, - {"dS", 16}, - {"O_SEQLEN", 17}, - {"M", 18}, - {"Z", 19}, - {"descaleQ", 20}, - {"descaleK", 21}, - {"descaleV", 22}, - {"descaleS", 23}, - {"scaleS", 24}, - {"amaxS", 25}, - {"amaxO", 26}, - {"QKV_RAGGED", 27}, - {"O_RAGGED", 28}, - {"K_TRANSPOSE", 29}, - {"AttnScale", 30}, - {"scaleO", 31}, - {"Z_INV", 32}, - {"descaleO", 33}, - {"descaledO", 34}, - {"descaledS", 35}, - {"descaledQ", 36}, - {"descaledK", 37}, - {"descaledV", 38}, - {"scaledS", 39}, - {"scaledQ", 40}, - {"scaledK", 41}, - {"scaledV", 42}, - {"amaxdS", 43}, - {"amaxdQ", 44}, - {"amaxdK", 45}, - {"amaxdV", 46}, - {"V_TRANSPOSE", 47}, - {"AttnScale_dS_K", 48}, - {"AttnScale_dSTranspose_Q", 49}, - {"DROPOUT_SCALE_dOVt_OdO", 50}, - {"DROPOUT_OFFSET", 51}, - {"DROPOUT_SEED", 52}, - {"VIRTUAL", 80} -}; - -static cudnn_frontend::Tensor createAmax( - const std::string& amax_tensor_name, - const cudnn_frontend::Tensor& prevBlockOutputTensor, - std::vector* ops) { +std::unordered_map tensor_name_to_uid = {{"Q", 1}, + {"K", 2}, + {"V", 3}, + {"O", 4}, + {"S", 5}, + {"B", 6}, + {"DROPOUT_SCALE", 7}, + {"S_CONST", 8}, + {"MNK_OVERRIDE", 9}, + {"dQ", 11}, + {"dK", 12}, + {"dV", 13}, + {"dO", 14}, + {"MASK_VAL", 15}, + {"dS", 16}, + {"O_SEQLEN", 17}, + {"M", 18}, + {"Z", 19}, + {"descaleQ", 20}, + {"descaleK", 21}, + {"descaleV", 22}, + {"descaleS", 23}, + {"scaleS", 24}, + {"amaxS", 25}, + {"amaxO", 26}, + {"QKV_RAGGED", 27}, + {"O_RAGGED", 28}, + {"K_TRANSPOSE", 29}, + {"AttnScale", 30}, + {"scaleO", 31}, + {"Z_INV", 32}, + {"descaleO", 33}, + {"descaledO", 34}, + {"descaledS", 35}, + {"descaledQ", 36}, + {"descaledK", 37}, + {"descaledV", 38}, + {"scaledS", 39}, + {"scaledQ", 40}, + {"scaledK", 41}, + {"scaledV", 42}, + {"amaxdS", 43}, + {"amaxdQ", 44}, + {"amaxdK", 45}, + {"amaxdV", 46}, + {"V_TRANSPOSE", 47}, + {"AttnScale_dS_K", 48}, + {"AttnScale_dSTranspose_Q", 49}, + {"DROPOUT_SCALE_dOVt_OdO", 50}, + {"DROPOUT_OFFSET", 51}, + {"DROPOUT_SEED", 52}, + {"VIRTUAL", 80}}; + +static cudnn_frontend::Tensor createAmax(const std::string& amax_tensor_name, + const cudnn_frontend::Tensor& prevBlockOutputTensor, + std::vector* ops) { int64_t amax_dim[4] = {1, 1, 1, 1}; int64_t amax_stride[4] = {1, 1, 1, 1}; - auto amaxTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid[amax_tensor_name], - amax_dim, amax_stride, false, false); + auto amaxTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid[amax_tensor_name], amax_dim, + amax_stride, false, false); // Define the amax descriptor auto reductionDesc = cudnn_frontend::ReductionDescBuilder() - .setMathPrecision(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_AMAX) - .build(); + .setMathPrecision(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_AMAX) + .build(); // Create a reduction amax Node - auto reduction_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + auto reduction_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) .setxDesc(prevBlockOutputTensor) .setyDesc(amaxTensor) .setreductionDesc(reductionDesc) @@ -97,13 +92,12 @@ static cudnn_frontend::Tensor createAmax( return amaxTensor; } 
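Note: the tensor_create, tensor_create_with_offset, pw_desc_create, unary_pw_op_create, and binary_pw_op_create calls used throughout these graph-construction routines are small shared helpers from the fused-attention utils and are not part of this diff. As a rough sketch of the pattern such a helper wraps (the signature and body below are an assumption for illustration, not code from this patch), it is a thin wrapper over cudnn_frontend::TensorBuilder:

    static cudnn_frontend::Tensor tensor_create(cudnnDataType_t type, int64_t id,
                                                int64_t const *dim, int64_t const *stride,
                                                bool is_virtual, bool is_by_value) {
      // Every tensor in these attention graphs is a rank-4 descriptor.
      return cudnn_frontend::TensorBuilder()
          .setDim(4, dim)
          .setStride(4, stride)
          .setId(id)                // UID from tensor_name_to_uid; virtuals add an offset
          .setAlignment(16)         // alignment in bytes
          .setDataType(type)
          .setVirtual(is_virtual)   // virtual: intermediate that never materializes in memory
          .setByValue(is_by_value)  // by-value: scalar passed directly in the variant pack
          .build();
    }

Virtual intermediates are given UIDs of the form tensor_name_to_uid["VIRTUAL"] + offset so they cannot collide with the fixed UIDs of the named I/O tensors registered in the map above.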
-static cudnn_frontend::Tensor createScale( - const cudnn_frontend::Tensor& prevBlockOutputTensor, - const std::string& scale_tensor_name, - cudnnDataType_t tensorType, - bool isOutputVirtual, bool isScaleByValue, - std::vector* ops, - const std::string& output_tensor_name ="") { +static cudnn_frontend::Tensor createScale(const cudnn_frontend::Tensor& prevBlockOutputTensor, + const std::string& scale_tensor_name, + cudnnDataType_t tensorType, bool isOutputVirtual, + bool isScaleByValue, + std::vector* ops, + const std::string& output_tensor_name = "") { int64_t scale_dim[4] = {1, 1, 1, 1}; int64_t scale_stride[4] = {1, 1, 1, 1}; @@ -111,74 +105,66 @@ static cudnn_frontend::Tensor createScale( int64_t output_stride[4]; for (int i = 0; i < 4; i++) { - output_dim[i] = prevBlockOutputTensor.getDim()[i]; - output_stride[i] = prevBlockOutputTensor.getStride()[i]; + output_dim[i] = prevBlockOutputTensor.getDim()[i]; + output_stride[i] = prevBlockOutputTensor.getStride()[i]; } - auto scaleTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name], - scale_dim, scale_stride, false, isScaleByValue); // is by value + auto scaleTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name], + scale_dim, scale_stride, false, isScaleByValue); // is by value - int64_t outputUID = isOutputVirtual ? tensor_name_to_uid["VIRTUAL"] - + tensor_name_to_uid[scale_tensor_name] + 5000 : - tensor_name_to_uid[output_tensor_name]; - auto afterScaleKTensor = tensor_create( - tensorType, outputUID, output_dim, - output_stride, isOutputVirtual, false); // is virtual + int64_t outputUID = + isOutputVirtual ? tensor_name_to_uid["VIRTUAL"] + tensor_name_to_uid[scale_tensor_name] + 5000 + : tensor_name_to_uid[output_tensor_name]; + auto afterScaleKTensor = tensor_create(tensorType, outputUID, output_dim, output_stride, + isOutputVirtual, false); // is virtual // Define the scale descriptor auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); // Create a Scale Node - auto scale_op = binary_pw_op_create( - prevBlockOutputTensor, scaleTensor, afterScaleKTensor, scaleDesc); + auto scale_op = + binary_pw_op_create(prevBlockOutputTensor, scaleTensor, afterScaleKTensor, scaleDesc); ops->push_back(std::move(scale_op)); return afterScaleKTensor; } -static cudnn_frontend::Tensor createScale( - const cudnn_frontend::Tensor& prevBlockOutputTensor, - const cudnn_frontend::Tensor& scaleTensor, - cudnnDataType_t tensorType, - bool isOutputVirtual, bool isScaleByValue, - std::vector* ops, - int UID_offset, const std::string& output_tensor_name ="") { +static cudnn_frontend::Tensor createScale(const cudnn_frontend::Tensor& prevBlockOutputTensor, + const cudnn_frontend::Tensor& scaleTensor, + cudnnDataType_t tensorType, bool isOutputVirtual, + bool isScaleByValue, + std::vector* ops, + int UID_offset, + const std::string& output_tensor_name = "") { int64_t output_dim[4]; int64_t output_stride[4]; for (int i = 0; i < 4; i++) { - output_dim[i] = prevBlockOutputTensor.getDim()[i]; - output_stride[i] = prevBlockOutputTensor.getStride()[i]; + output_dim[i] = prevBlockOutputTensor.getDim()[i]; + output_stride[i] = prevBlockOutputTensor.getStride()[i]; } - int64_t outputUID = isOutputVirtual ? - tensor_name_to_uid["VIRTUAL"] + UID_offset : - tensor_name_to_uid[output_tensor_name]; - auto afterScaleTensor = tensor_create( - tensorType, outputUID, output_dim, - output_stride, isOutputVirtual, false); // is virtual + int64_t outputUID = isOutputVirtual ? 
tensor_name_to_uid["VIRTUAL"] + UID_offset + : tensor_name_to_uid[output_tensor_name]; + auto afterScaleTensor = tensor_create(tensorType, outputUID, output_dim, output_stride, + isOutputVirtual, false); // is virtual // Define the scale descriptor auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); // Create a Scale Node - auto scale_op = binary_pw_op_create( - prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc); + auto scale_op = + binary_pw_op_create(prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc); ops->push_back(std::move(scale_op)); return afterScaleTensor; } static cudnn_frontend::Tensor createScaleWithOffset( - const cudnn_frontend::Tensor& prevBlockOutputTensor, - const std::string& scale_tensor_name, - NVTE_QKV_Layout layout, - cudnnDataType_t tensorType, - bool isOutputVirtual, - bool isScaleByValue, - std::vector* ops, - std::shared_ptr offsetTensor, - const std::string& output_tensor_name ="") { + const cudnn_frontend::Tensor& prevBlockOutputTensor, const std::string& scale_tensor_name, + NVTE_QKV_Layout layout, cudnnDataType_t tensorType, bool isOutputVirtual, bool isScaleByValue, + std::vector* ops, + std::shared_ptr offsetTensor, + const std::string& output_tensor_name = "") { int64_t scale_dim[4] = {1, 1, 1, 1}; int64_t scale_stride[4] = {1, 1, 1, 1}; @@ -186,49 +172,45 @@ static cudnn_frontend::Tensor createScaleWithOffset( int64_t output_stride[4]; // If output tensor is dQ, dK, or dV, we need to generate QKV interleaved strides if (output_tensor_name == "dQ" || output_tensor_name == "dK" || output_tensor_name == "dV") { - for (int i = 0; i < 4; i++) { - output_dim[i] = prevBlockOutputTensor.getDim()[i]; - } - generateMatrixStrides(output_dim[0], output_dim[1], output_dim[2], - 0 /*s_kv = 0 for placeholder*/, - output_dim[3], output_stride, - layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); + for (int i = 0; i < 4; i++) { + output_dim[i] = prevBlockOutputTensor.getDim()[i]; + } + generateMatrixStrides(output_dim[0], output_dim[1], output_dim[2], + 0 /*s_kv = 0 for placeholder*/, output_dim[3], output_stride, layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); } else { - // Otherwise output dim and stride should be the same as prev block dim and stride - for (int i = 0; i < 4; i++) { - output_dim[i] = prevBlockOutputTensor.getDim()[i]; - output_stride[i] = prevBlockOutputTensor.getStride()[i]; - } + // Otherwise output dim and stride should be the same as prev block dim and stride + for (int i = 0; i < 4; i++) { + output_dim[i] = prevBlockOutputTensor.getDim()[i]; + output_stride[i] = prevBlockOutputTensor.getStride()[i]; + } } - auto scaleTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name], - scale_dim, scale_stride, false, isScaleByValue); // is by value + auto scaleTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name], + scale_dim, scale_stride, false, isScaleByValue); // is by value cudnnDataType_t outputDataType = isOutputVirtual ? CUDNN_DATA_FLOAT : tensorType; - int64_t outputUID = isOutputVirtual ? - tensor_name_to_uid["VIRTUAL"] + tensor_name_to_uid[scale_tensor_name] + 7000 : - tensor_name_to_uid[output_tensor_name]; - auto afterScaleTensor = tensor_create_with_offset( - outputDataType, outputUID, output_dim, - output_stride, isOutputVirtual, false, offsetTensor); // is virtual + int64_t outputUID = + isOutputVirtual ? 
tensor_name_to_uid["VIRTUAL"] + tensor_name_to_uid[scale_tensor_name] + 7000 + : tensor_name_to_uid[output_tensor_name]; + auto afterScaleTensor = + tensor_create_with_offset(outputDataType, outputUID, output_dim, output_stride, + isOutputVirtual, false, offsetTensor); // is virtual // Define the scale descriptor auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); // Create a Scale Node - auto scale_op = binary_pw_op_create( - prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc); + auto scale_op = + binary_pw_op_create(prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc); ops->push_back(std::move(scale_op)); return afterScaleTensor; } static cudnn_frontend::Tensor createSoftmaxForward( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - std::vector* ops, - const cudnn_frontend::Tensor& prevBlockOutputTensor, - bool isTraining) { + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, std::vector* ops, + const cudnn_frontend::Tensor& prevBlockOutputTensor, bool isTraining) { int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; @@ -236,33 +218,30 @@ static cudnn_frontend::Tensor createSoftmaxForward( int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1}; // max (x) (M tensor) - auto afterMaxReductionTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["M"], - afterReduction_dim, afterReduction_stride, - !isTraining, false); // not virtual if training is true, - // virtual if training is false + auto afterMaxReductionTensor = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["M"], afterReduction_dim, + afterReduction_stride, !isTraining, false); // not virtual if training is true, + // virtual if training is false // x - max(x) - auto afterSubtractionTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 151, - afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + auto afterSubtractionTensor = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 151, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual // e^(x - max(x)) - auto afterExponentTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 152, - afterBMM1_dim, afterBMM1_stride, true, false); // is virtual; + auto afterExponentTensor = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 152, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual; // sum (e^(x - max(x))) (Z tensor) - auto zTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["Z"], - afterReduction_dim, afterReduction_stride, true, false); // is virtual + auto zTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["Z"], afterReduction_dim, + afterReduction_stride, true, false); // is virtual // 1 / sum (e^(x - max(x))) (Z_INV tensor) - auto zInvTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"], - afterReduction_dim, afterReduction_stride, - !isTraining, false); // not virtual if training is true, - // virtual if training is false + auto zInvTensor = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"], afterReduction_dim, + afterReduction_stride, !isTraining, false); // not virtual if training is true, + // virtual if training is false // Final softmax output (After exponent * Z_INV) - auto beforeDropoutTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 153, - afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + auto beforeDropoutTensor = + tensor_create(CUDNN_DATA_FLOAT, 
tensor_name_to_uid["VIRTUAL"] + 153, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual // Define the reduction descriptor auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder() @@ -271,27 +250,25 @@ static cudnn_frontend::Tensor createSoftmaxForward( .build(); // Create a reduction max Node - auto reductionMax_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(prevBlockOutputTensor) - .setyDesc(afterMaxReductionTensor) - .setreductionDesc(reductionMaxDesc) - .build(); + auto reductionMax_op = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(prevBlockOutputTensor) + .setyDesc(afterMaxReductionTensor) + .setreductionDesc(reductionMaxDesc) + .build(); // Define the subtract descriptor auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); // Create a subtract Node - auto subtract_op = binary_pw_op_create( - prevBlockOutputTensor, afterMaxReductionTensor, - afterSubtractionTensor, subtractDesc); + auto subtract_op = binary_pw_op_create(prevBlockOutputTensor, afterMaxReductionTensor, + afterSubtractionTensor, subtractDesc); // Define the exponent descriptor auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); // Create a exponent Node - auto exponent_op = unary_pw_op_create( - afterSubtractionTensor, afterExponentTensor, exponentDesc); + auto exponent_op = unary_pw_op_create(afterSubtractionTensor, afterExponentTensor, exponentDesc); // Define the reduction descriptor auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() @@ -300,12 +277,12 @@ static cudnn_frontend::Tensor createSoftmaxForward( .build(); // Create a reduction add Node - auto reductionAdd_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(afterExponentTensor) - .setyDesc(zTensor) - .setreductionDesc(reductionAddDesc) - .build(); + auto reductionAdd_op = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(afterExponentTensor) + .setyDesc(zTensor) + .setreductionDesc(reductionAddDesc) + .build(); // Define the reciprocal descriptor auto reciprocalDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_RECIPROCAL); @@ -317,8 +294,8 @@ static cudnn_frontend::Tensor createSoftmaxForward( auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); // Create a multiply Node - auto mutliply_op = binary_pw_op_create( - afterExponentTensor, zInvTensor, beforeDropoutTensor, multiplyDesc); + auto mutliply_op = + binary_pw_op_create(afterExponentTensor, zInvTensor, beforeDropoutTensor, multiplyDesc); ops->push_back(std::move(reductionMax_op)); ops->push_back(std::move(subtract_op)); @@ -331,12 +308,10 @@ static cudnn_frontend::Tensor createSoftmaxForward( } static cudnn_frontend::Tensor createDropoutForward( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - double probability, - std::vector* ops, - const cudnn_frontend::Tensor& beforeDropoutTensor) { - NVTE_CHECK(ops->size() > 0, - "Dropout DAG constructed incorrectly as the first one"); + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, double probability, + std::vector* ops, + const cudnn_frontend::Tensor& beforeDropoutTensor) { + NVTE_CHECK(ops->size() > 0, "Dropout DAG constructed incorrectly as the first one"); int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; @@ -345,18 +320,17 @@ static cudnn_frontend::Tensor createDropoutForward( int64_t scale_stride[4] = 
{1, 1, 1, 1}; // Mask for the dropout - auto dropoutMaskTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 250, - afterBMM1_dim, afterBMM1_stride, true, false); // is virtual - auto dropoutSeedTensor = tensor_create( - CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"], - scale_dim, scale_stride, false, false); // is by value - auto dropoutOffsetTensor = tensor_create( - CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"], - scale_dim, scale_stride, false, false); // is by value + auto dropoutMaskTensor = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 250, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + auto dropoutSeedTensor = tensor_create(CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"], + scale_dim, scale_stride, false, false); // is by value + auto dropoutOffsetTensor = tensor_create(CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"], + scale_dim, scale_stride, false, false); // is by value // After dropout tensor befor scale - auto beforeDropoutScaleTensor = cudnn_frontend::TensorBuilder() + auto beforeDropoutScaleTensor = + cudnn_frontend::TensorBuilder() .setDim(4, afterBMM1_dim) .setStride(4, afterBMM1_stride) .setId(tensor_name_to_uid["VIRTUAL"] + 201) @@ -367,44 +341,40 @@ static cudnn_frontend::Tensor createDropoutForward( .setReorderType(cudnn_frontend::TensorReordering_t::F16x16) .build(); // Scale after dropout - auto scaleDropoutTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"], - scale_dim, scale_stride, false, true); // is by value + auto scaleDropoutTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"], + scale_dim, scale_stride, false, true); // is by value // After Scale - auto afterDropout_before_quan_S = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202, - afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + auto afterDropout_before_quan_S = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual // Define the reduction descriptor auto rngDesc = cudnn_frontend::RngDescBuilder() - .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) - .setBernoulliDistProbability(1.0 - probability) - .build(); + .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) + .setBernoulliDistProbability(1.0 - probability) + .build(); // Create a rng Node auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) - .setyDesc(dropoutMaskTensor) - .setSeedDesc(dropoutSeedTensor) - .setOffsetDesc(dropoutOffsetTensor) - .setRngDesc(rngDesc) - .build(); - + .setyDesc(dropoutMaskTensor) + .setSeedDesc(dropoutSeedTensor) + .setOffsetDesc(dropoutOffsetTensor) + .setRngDesc(rngDesc) + .build(); // Define the multiply mask descriptor auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); // Create a multiply mask Node - auto maskMul_op = binary_pw_op_create( - beforeDropoutTensor, dropoutMaskTensor, - beforeDropoutScaleTensor, maskMulDesc); + auto maskMul_op = binary_pw_op_create(beforeDropoutTensor, dropoutMaskTensor, + beforeDropoutScaleTensor, maskMulDesc); // Define the multiply scale descriptor auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); // Create a multiply mask Node - auto scaleMul_op = binary_pw_op_create( - beforeDropoutScaleTensor, scaleDropoutTensor, - afterDropout_before_quan_S, scaleMulDesc); + auto scaleMul_op = binary_pw_op_create(beforeDropoutScaleTensor, scaleDropoutTensor, + 
afterDropout_before_quan_S, scaleMulDesc); ops->push_back(std::move(rng_op)); ops->push_back(std::move(maskMul_op)); @@ -414,13 +384,10 @@ static cudnn_frontend::Tensor createDropoutForward( } static cudnn_frontend::Tensor createDropoutBackward( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - double probability, - std::vector* ops, - const cudnn_frontend::Tensor& beforeDropoutTensor, - const cudnn_frontend::Tensor& dropoutMaskTensor) { - NVTE_CHECK(ops->size() > 0, - "Dropout DAG constructed incorrectly as the first one"); + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, double probability, + std::vector* ops, const cudnn_frontend::Tensor& beforeDropoutTensor, + const cudnn_frontend::Tensor& dropoutMaskTensor) { + NVTE_CHECK(ops->size() > 0, "Dropout DAG constructed incorrectly as the first one"); int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; @@ -428,15 +395,14 @@ static cudnn_frontend::Tensor createDropoutBackward( int64_t scale_dim[4] = {1, 1, 1, 1}; int64_t scale_stride[4] = {1, 1, 1, 1}; - auto dropoutSeedTensor = tensor_create( - CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"], - scale_dim, scale_stride, false, false); // is by value - auto dropoutOffsetTensor = tensor_create( - CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"], - scale_dim, scale_stride, false, false); // is by value + auto dropoutSeedTensor = tensor_create(CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"], + scale_dim, scale_stride, false, false); // is by value + auto dropoutOffsetTensor = tensor_create(CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"], + scale_dim, scale_stride, false, false); // is by value // After dropout tensor befor scale - auto beforeDropoutScaleTensor = cudnn_frontend::TensorBuilder() + auto beforeDropoutScaleTensor = + cudnn_frontend::TensorBuilder() .setDim(4, afterBMM1_dim) .setStride(4, afterBMM1_stride) .setId(tensor_name_to_uid["VIRTUAL"] + 201) @@ -447,43 +413,40 @@ static cudnn_frontend::Tensor createDropoutBackward( .setReorderType(cudnn_frontend::TensorReordering_t::F16x16) .build(); // Scale after dropout (1 / (1 - p)) - auto scaleDropoutTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"], - scale_dim, scale_stride, false, true); // is by value + auto scaleDropoutTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"], + scale_dim, scale_stride, false, true); // is by value // After Scale - auto afterDropout_before_quan_S = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202, - afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + auto afterDropout_before_quan_S = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual // Define the reduction descriptor auto rngDesc = cudnn_frontend::RngDescBuilder() - .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) - .setBernoulliDistProbability(1.0 - probability) - .build(); + .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) + .setBernoulliDistProbability(1.0 - probability) + .build(); // Create a rng Node auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) - .setyDesc(dropoutMaskTensor) - .setSeedDesc(dropoutSeedTensor) - .setOffsetDesc(dropoutOffsetTensor) - .setRngDesc(rngDesc) - .build(); + .setyDesc(dropoutMaskTensor) + .setSeedDesc(dropoutSeedTensor) + .setOffsetDesc(dropoutOffsetTensor) + .setRngDesc(rngDesc) + .build(); // Define the multiply mask 
descriptor auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); // Create a multiply mask Node - auto maskMul_op = binary_pw_op_create( - beforeDropoutTensor, dropoutMaskTensor, - beforeDropoutScaleTensor, maskMulDesc); + auto maskMul_op = binary_pw_op_create(beforeDropoutTensor, dropoutMaskTensor, + beforeDropoutScaleTensor, maskMulDesc); // Define the multiply scale descriptor auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); // Create a multiply mask Node - auto scaleMul_op = binary_pw_op_create( - beforeDropoutScaleTensor, scaleDropoutTensor, - afterDropout_before_quan_S, scaleMulDesc); + auto scaleMul_op = binary_pw_op_create(beforeDropoutScaleTensor, scaleDropoutTensor, + afterDropout_before_quan_S, scaleMulDesc); ops->push_back(std::move(rng_op)); ops->push_back(std::move(maskMul_op)); @@ -492,12 +455,10 @@ static cudnn_frontend::Tensor createDropoutBackward( return afterDropout_before_quan_S; } -static cudnn_frontend::Tensor createSoftmaxBackward( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - std::vector* ops, - const cudnn_frontend::Tensor& dyTensor) { - NVTE_CHECK(ops->size() > 0, - "Softmax backward constructed incorrectly as the first one"); +static cudnn_frontend::Tensor createSoftmaxBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, + std::vector* ops, + const cudnn_frontend::Tensor& dyTensor) { + NVTE_CHECK(ops->size() > 0, "Softmax backward constructed incorrectly as the first one"); int64_t dx_dim[4] = {b, h, s_q, s_kv}; int64_t dx_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; @@ -506,41 +467,38 @@ static cudnn_frontend::Tensor createSoftmaxBackward( int64_t M_Z_stride[4] = {h * s_q, s_q, 1, 1}; // Creating all tensors - auto MTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["M"], - M_Z_dim, M_Z_stride, false, false); // not virtual - auto ZInvTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"], - M_Z_dim, M_Z_stride, false, false); // not virtual - auto dxAfterSubtractionTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 252, - dx_dim, dx_stride, true, false); // is virtual - auto dxAfterExponentiation = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 253, - dx_dim, dx_stride, true, false); // is virtual - auto dxBeforeDropout_QKt_Tensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 254, - dx_dim, dx_stride, true, false); // is virtual + auto MTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["M"], M_Z_dim, M_Z_stride, + false, false); // not virtual + auto ZInvTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"], M_Z_dim, + M_Z_stride, false, false); // not virtual + auto dxAfterSubtractionTensor = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 252, dx_dim, dx_stride, true, + false); // is virtual + auto dxAfterExponentiation = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 253, + dx_dim, dx_stride, true, false); // is virtual + auto dxBeforeDropout_QKt_Tensor = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 254, dx_dim, dx_stride, true, + false); // is virtual // Creating all ops // sub (dy - M) auto subtractionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - auto subtractionOp = binary_pw_op_create( - dyTensor, MTensor, dxAfterSubtractionTensor, subtractionDesc); + auto subtractionOp = + binary_pw_op_create(dyTensor, MTensor, dxAfterSubtractionTensor, subtractionDesc); // Define the exponent descriptor auto 
exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); // Create a exponent Node. (exp(dy - M)) - auto exponentOp = unary_pw_op_create( - dxAfterSubtractionTensor, dxAfterExponentiation, exponentDesc); + auto exponentOp = + unary_pw_op_create(dxAfterSubtractionTensor, dxAfterExponentiation, exponentDesc); // Define the pw multiply descriptor auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); // Create a multiply Node - auto mutliplyOp = binary_pw_op_create( - dxAfterExponentiation, ZInvTensor, dxBeforeDropout_QKt_Tensor, multiplyDesc); + auto mutliplyOp = binary_pw_op_create(dxAfterExponentiation, ZInvTensor, + dxBeforeDropout_QKt_Tensor, multiplyDesc); ops->push_back(std::move(subtractionOp)); ops->push_back(std::move(exponentOp)); @@ -550,58 +508,50 @@ static cudnn_frontend::Tensor createSoftmaxBackward( } static cudnn_frontend::Tensor createQKBMM( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, - cudnnDataType_t tensorType, - std::vector* ops, - const cudnn_frontend::Tensor &qTensor, - const cudnn_frontend::Tensor &kTensor, - const cudnn_frontend::Tensor &mnkOverride, - std::shared_ptr QKVRaggedOffsetTensor) { + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, + cudnnDataType_t tensorType, std::vector* ops, + const cudnn_frontend::Tensor& qTensor, const cudnn_frontend::Tensor& kTensor, + const cudnn_frontend::Tensor& mnkOverride, + std::shared_ptr QKVRaggedOffsetTensor) { // Creates the necessary tensor descriptors int64_t k_transpose_dim[4] = {b, h, d, s_kv}; int64_t k_transpose_stride[4]; - generateMatrixStrides( - b, h, s_q, s_kv, d, - k_transpose_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); + generateMatrixStrides(b, h, s_q, s_kv, d, k_transpose_stride, layout, + NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); int64_t s_dim[4] = {b, h, s_q, s_kv}; int64_t s_stride[4]; generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - auto kTransposeTensor = tensor_create_with_offset( - tensorType, tensor_name_to_uid["K_TRANSPOSE"], - k_transpose_dim, k_transpose_stride, - false, false, QKVRaggedOffsetTensor); // is virtual + auto kTransposeTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["K_TRANSPOSE"], + k_transpose_dim, k_transpose_stride, false, + false, QKVRaggedOffsetTensor); // is virtual // First GEMM output - auto afterQKTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 1, - s_dim, s_stride, true, false); // is virtual + auto afterQKTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 1, s_dim, + s_stride, true, false); // is virtual // Define the matmul desc auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(-2000000) - .build(); + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(-2000000) + .build(); // Create reshape node for K -> K.T - auto reshape_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(kTensor) - .setyDesc(kTransposeTensor) - .build(); + auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(kTensor) + .setyDesc(kTransposeTensor) + .build(); // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(qTensor) - .setbMatDesc(kTransposeTensor) - .setcMatDesc(afterQKTensor) - .setmOverrideDesc(mnkOverride) - 
.setnOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); + auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(qTensor) + .setbMatDesc(kTransposeTensor) + .setcMatDesc(afterQKTensor) + .setmOverrideDesc(mnkOverride) + .setnOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); ops->push_back(std::move(reshape_op)); ops->push_back(std::move(matmulOp)); @@ -610,102 +560,87 @@ static cudnn_frontend::Tensor createQKBMM( } static cudnn_frontend::Tensor createSVBMM( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, - cudnnDataType_t tensorType, - std::vector* ops, - const cudnn_frontend::Tensor &softmaxTensor, - const cudnn_frontend::Tensor &mnkOverride, - std::shared_ptr QKVRaggedOffsetTensor) { - NVTE_CHECK(ops->size() > 0, - "BMM2 op constructed incorrectly as the first one"); - - int64_t v_dim[4] = {b, h, s_kv, d}; + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, + cudnnDataType_t tensorType, std::vector* ops, + const cudnn_frontend::Tensor& softmaxTensor, const cudnn_frontend::Tensor& mnkOverride, + std::shared_ptr QKVRaggedOffsetTensor) { + NVTE_CHECK(ops->size() > 0, "BMM2 op constructed incorrectly as the first one"); + + int64_t v_dim[4] = {b, h, s_kv, d}; int64_t v_stride[4]; generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - int64_t o_dim[4] = {b, h, s_q, d}; + int64_t o_dim[4] = {b, h, s_q, d}; int64_t o_stride[4]; generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - auto vTensor = tensor_create_with_offset( - tensorType, tensor_name_to_uid["V"], - v_dim, v_stride, false, false, QKVRaggedOffsetTensor); + auto vTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["V"], v_dim, v_stride, + false, false, QKVRaggedOffsetTensor); // Second fprop GEMM output - auto oTensor = tensor_create( - tensorType, tensor_name_to_uid["VIRTUAL"] + 300, - o_dim, o_stride, true, false); // is virtual + auto oTensor = tensor_create(tensorType, tensor_name_to_uid["VIRTUAL"] + 300, o_dim, o_stride, + true, false); // is virtual // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .build(); + auto matmulDesc = cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build(); // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(softmaxTensor) - .setbMatDesc(vTensor) - .setcMatDesc(oTensor) - .setmOverrideDesc(mnkOverride) - .setkOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); + auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(softmaxTensor) + .setbMatDesc(vTensor) + .setcMatDesc(oTensor) + .setmOverrideDesc(mnkOverride) + .setkOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); ops->push_back(std::move(matmulOp)); return oTensor; } -static cudnn_frontend::Tensor createSdOBMM( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - cudnnDataType_t tensorType, - std::vector* ops, - const cudnn_frontend::Tensor &softmaxTensor, - const cudnn_frontend::Tensor &dOTensor, - const cudnn_frontend::Tensor &mnkOverride) { - NVTE_CHECK(ops->size() > 0, - "BMM2 op constructed incorrectly as the first one"); - - int64_t s_dim_transpose[4] = {b, h, s_kv, s_q}; +static cudnn_frontend::Tensor createSdOBMM(int64_t b, 
int64_t h, int64_t s_q, int64_t s_kv, + int64_t d, cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor& softmaxTensor, + const cudnn_frontend::Tensor& dOTensor, + const cudnn_frontend::Tensor& mnkOverride) { + NVTE_CHECK(ops->size() > 0, "BMM2 op constructed incorrectly as the first one"); + + int64_t s_dim_transpose[4] = {b, h, s_kv, s_q}; int64_t s_stride_transpose[4] = {h * s_kv * s_q, s_kv * s_q, 1, s_kv}; - int64_t v_dim[4] = {b, h, s_kv, d}; + int64_t v_dim[4] = {b, h, s_kv, d}; int64_t v_stride[4] = {h * s_kv * d, d, h * d, 1}; - auto sTransposeTensor = tensor_create( - tensorType, tensor_name_to_uid["VIRTUAL"] + 499, - s_dim_transpose, s_stride_transpose, - true, false); // is virtual + auto sTransposeTensor = + tensor_create(tensorType, tensor_name_to_uid["VIRTUAL"] + 499, s_dim_transpose, + s_stride_transpose, true, false); // is virtual // S.T * dO - auto dVTensor_before_dequan_S = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 500, - v_dim, v_stride, - true, false); // is virtual + auto dVTensor_before_dequan_S = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 500, v_dim, v_stride, true, + false); // is virtual // Create reshape node for softmax -> softmax.T - auto reshape_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(softmaxTensor) - .setyDesc(sTransposeTensor) - .build(); + auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(softmaxTensor) + .setyDesc(sTransposeTensor) + .build(); // Define the matmul desc auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0) - .build(); + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0) + .build(); // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(sTransposeTensor) - .setbMatDesc(dOTensor) - .setcMatDesc(dVTensor_before_dequan_S) - .setmOverrideDesc(mnkOverride) - .setkOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); + auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(sTransposeTensor) + .setbMatDesc(dOTensor) + .setcMatDesc(dVTensor_before_dequan_S) + .setmOverrideDesc(mnkOverride) + .setkOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); ops->push_back(std::move(reshape_op)); ops->push_back(std::move(matmulOp)); @@ -714,15 +649,12 @@ static cudnn_frontend::Tensor createSdOBMM( } static cudnn_frontend::Tensor createdOVBMM( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, - cudnnDataType_t tensorType, - std::vector* ops, - const cudnn_frontend::Tensor &dOTensor, - const cudnn_frontend::Tensor &mnkOverride, - std::shared_ptr QKVRaggedOffsetTensor) { + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, + cudnnDataType_t tensorType, std::vector* ops, + const cudnn_frontend::Tensor& dOTensor, const cudnn_frontend::Tensor& mnkOverride, + std::shared_ptr QKVRaggedOffsetTensor) { // Creates the necessary tensor descriptors - int64_t v_dim[4] = {b, h, s_kv, d}; + int64_t v_dim[4] = {b, h, s_kv, d}; int64_t v_stride[4]; generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); @@ -737,43 +669,37 @@ static cudnn_frontend::Tensor createdOVBMM( int64_t s_stride[4]; generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, 
NVTE_QKV_Matrix::NVTE_S_Matrix); - auto vTensor = tensor_create_with_offset( - tensorType, tensor_name_to_uid["V"], - v_dim, v_stride, - false, false, QKVRaggedOffsetTensor); - auto vTransposeTensor = tensor_create_with_offset( - tensorType, tensor_name_to_uid["V_TRANSPOSE"], - v_transpose_dim, v_transpose_stride, - false, false, QKVRaggedOffsetTensor); // is virtual + auto vTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["V"], v_dim, v_stride, + false, false, QKVRaggedOffsetTensor); + auto vTransposeTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["V_TRANSPOSE"], + v_transpose_dim, v_transpose_stride, false, + false, QKVRaggedOffsetTensor); // is virtual // dO * V.T - auto afterdOVTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 600, - s_dim, s_stride, true, false); // is virtual + auto afterdOVTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 600, s_dim, + s_stride, true, false); // is virtual // Define the matmul desc auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0) - .build(); + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0) + .build(); // Create reshape node for V -> V.T - auto reshape_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(vTensor) - .setyDesc(vTransposeTensor) - .build(); + auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(vTensor) + .setyDesc(vTransposeTensor) + .build(); // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dOTensor) - .setbMatDesc(vTransposeTensor) - .setcMatDesc(afterdOVTensor) - .setmOverrideDesc(mnkOverride) - .setnOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); + auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dOTensor) + .setbMatDesc(vTransposeTensor) + .setcMatDesc(afterdOVTensor) + .setmOverrideDesc(mnkOverride) + .setnOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); ops->push_back(std::move(reshape_op)); ops->push_back(std::move(matmulOp)); @@ -782,40 +708,37 @@ static cudnn_frontend::Tensor createdOVBMM( } static cudnn_frontend::Tensor createdOAndORowReductionChain( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, - std::vector* ops, - const cudnn_frontend::Tensor &O_after_dequan, - const cudnn_frontend::Tensor &dO_after_dequan, - const cudnn_frontend::Tensor &dropoutScale_dOVt_OdO_Tensor) { + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, + std::vector* ops, const cudnn_frontend::Tensor& O_after_dequan, + const cudnn_frontend::Tensor& dO_after_dequan, + const cudnn_frontend::Tensor& dropoutScale_dOVt_OdO_Tensor) { int64_t o_dim[4] = {b, h, s_q, d}; int64_t o_stride[4]; generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); int64_t o_dim_row_sum[4] = {b, h, s_q, 1}; int64_t o_dim_row_sum_stride[4] = {s_q * h, s_q, 1, 1}; - auto O_dO_after_pointwise_multiply = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 700, - o_dim, o_stride, true, false); // is virtual - auto O_dO_after_dropout_scale = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 701, - o_dim, o_stride, true, false); // is virtual - auto O_dO_after_rowsum = tensor_create( - CUDNN_DATA_FLOAT, 
tensor_name_to_uid["VIRTUAL"] + 702, - o_dim_row_sum, o_dim_row_sum_stride, true, false); // is virtual + auto O_dO_after_pointwise_multiply = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 700, o_dim, o_stride, true, + false); // is virtual + auto O_dO_after_dropout_scale = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 701, o_dim, o_stride, true, + false); // is virtual + auto O_dO_after_rowsum = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 702, o_dim_row_sum, + o_dim_row_sum_stride, true, false); // is virtual // Define the pw multiply descriptor auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); // Create a multiply Node - auto mutliply_op = binary_pw_op_create( - O_after_dequan, dO_after_dequan, - O_dO_after_pointwise_multiply, multiplyDesc); + auto mutliply_op = binary_pw_op_create(O_after_dequan, dO_after_dequan, + O_dO_after_pointwise_multiply, multiplyDesc); // Create multiply node with dropout scale - auto dropout_scale_multiply_op = binary_pw_op_create( - O_dO_after_pointwise_multiply, dropoutScale_dOVt_OdO_Tensor, - O_dO_after_dropout_scale, multiplyDesc); + auto dropout_scale_multiply_op = + binary_pw_op_create(O_dO_after_pointwise_multiply, dropoutScale_dOVt_OdO_Tensor, + O_dO_after_dropout_scale, multiplyDesc); // Define the reduction descriptor auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() @@ -824,12 +747,12 @@ static cudnn_frontend::Tensor createdOAndORowReductionChain( .build(); // Create a reduction add Node - auto reductionAdd_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(O_dO_after_dropout_scale) - .setyDesc(O_dO_after_rowsum) - .setreductionDesc(reductionAddDesc) - .build(); + auto reductionAdd_op = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(O_dO_after_dropout_scale) + .setyDesc(O_dO_after_rowsum) + .setreductionDesc(reductionAddDesc) + .build(); ops->push_back(std::move(mutliply_op)); ops->push_back(std::move(dropout_scale_multiply_op)); @@ -839,45 +762,37 @@ static cudnn_frontend::Tensor createdOAndORowReductionChain( } static cudnn_frontend::Tensor createBiasSubtractionSoftmaxMulChain( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, - std::vector* ops, - const cudnn_frontend::Tensor &dS_after_dropout, - const cudnn_frontend::Tensor &AfterDropout_before_quan_S, - const cudnn_frontend::Tensor &O_dO_after_rowsum, - const cudnn_frontend::Tensor &attnScale) { + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, + std::vector* ops, const cudnn_frontend::Tensor& dS_after_dropout, + const cudnn_frontend::Tensor& AfterDropout_before_quan_S, + const cudnn_frontend::Tensor& O_dO_after_rowsum, const cudnn_frontend::Tensor& attnScale) { int64_t o_dim[4] = {b, h, s_q, s_kv}; int64_t o_stride[4]; generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - auto dS_minus_O_dO = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 800, - o_dim, o_stride, true, false); // is virtual - auto AfterAttnScale_before_dS = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 801, - o_dim, o_stride, true, false); // is virtual - auto S_mul_dS_minus_O_dO = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 802, - o_dim, o_stride, true, false); // is virtual + auto dS_minus_O_dO = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 
800, o_dim, + o_stride, true, false); // is virtual + auto AfterAttnScale_before_dS = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 801, o_dim, o_stride, true, + false); // is virtual + auto S_mul_dS_minus_O_dO = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 802, + o_dim, o_stride, true, false); // is virtual // Define the pw subtraction descriptor auto subDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); // Create a subtraction Node - auto sub_op = binary_pw_op_create( - dS_after_dropout, O_dO_after_rowsum, dS_minus_O_dO, subDesc); + auto sub_op = binary_pw_op_create(dS_after_dropout, O_dO_after_rowsum, dS_minus_O_dO, subDesc); // Define the pw multiplication descriptor auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); // dS_minus_O_dO * attnScale - auto mutliply_attn_scale_op = binary_pw_op_create( - dS_minus_O_dO, attnScale, - AfterAttnScale_before_dS, multiplyDesc); + auto mutliply_attn_scale_op = + binary_pw_op_create(dS_minus_O_dO, attnScale, AfterAttnScale_before_dS, multiplyDesc); // AfterDropout_before_quan_S * AfterAttnScale_before_dS - auto mutliply_op = binary_pw_op_create( - AfterDropout_before_quan_S, AfterAttnScale_before_dS, - S_mul_dS_minus_O_dO, multiplyDesc); + auto mutliply_op = binary_pw_op_create(AfterDropout_before_quan_S, AfterAttnScale_before_dS, + S_mul_dS_minus_O_dO, multiplyDesc); ops->push_back(std::move(sub_op)); ops->push_back(std::move(mutliply_attn_scale_op)); @@ -886,49 +801,45 @@ static cudnn_frontend::Tensor createBiasSubtractionSoftmaxMulChain( return S_mul_dS_minus_O_dO; } -static cudnn_frontend::Tensor createdSKBMM( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - std::vector* ops, - const cudnn_frontend::Tensor &dSTensor, - const cudnn_frontend::Tensor &kTensor, - const cudnn_frontend::Tensor &mnkOverride) { +static cudnn_frontend::Tensor createdSKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, + int64_t d, std::vector* ops, + const cudnn_frontend::Tensor& dSTensor, + const cudnn_frontend::Tensor& kTensor, + const cudnn_frontend::Tensor& mnkOverride) { // Creates the necessary tensor descriptors int64_t after_dSK_dim[4] = {b, h, s_kv, d}; int64_t after_dSK_stride[4] = {h * s_kv * d, d, h * d, 1}; // dS * K - auto After_dS_K = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 875, - after_dSK_dim, after_dSK_stride, true, false); // is virtual + auto After_dS_K = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 875, + after_dSK_dim, after_dSK_stride, true, false); // is virtual // Define the matmul desc auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0) - .build(); + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0) + .build(); // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dSTensor) - .setbMatDesc(kTensor) - .setcMatDesc(After_dS_K) - .setmOverrideDesc(mnkOverride) - .setkOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); + auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dSTensor) + .setbMatDesc(kTensor) + .setcMatDesc(After_dS_K) + .setmOverrideDesc(mnkOverride) + .setkOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); ops->push_back(std::move(matmulOp)); return After_dS_K; } -static cudnn_frontend::Tensor createdSQBMM( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - 
NVTE_QKV_Layout layout, - std::vector* ops, - const cudnn_frontend::Tensor &dSTensor, - const cudnn_frontend::Tensor &qTensor, - const cudnn_frontend::Tensor &mnkOverride) { +static cudnn_frontend::Tensor createdSQBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, + int64_t d, NVTE_QKV_Layout layout, + std::vector* ops, + const cudnn_frontend::Tensor& dSTensor, + const cudnn_frontend::Tensor& qTensor, + const cudnn_frontend::Tensor& mnkOverride) { // Creates the necessary tensor descriptors int64_t dS_stride[4]; generateMatrixStrides(b, h, s_q, s_kv, d, dS_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); @@ -943,39 +854,36 @@ static cudnn_frontend::Tensor createdSQBMM( int64_t after_dSTranspose_Q_dim[4] = {b, h, s_kv, d}; int64_t after_dSTranspose_Q_stride[4] = {h * s_kv * d, d, h * d, 1}; - auto dSTransposeTensor = tensor_create( - CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["VIRTUAL"] + 650, - dS_transpose_dim, dS_transpose_stride, true, false); // is virtual + auto dSTransposeTensor = + tensor_create(CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["VIRTUAL"] + 650, dS_transpose_dim, + dS_transpose_stride, true, false); // is virtual // dS.T * Q - auto After_dSTranspose_Q = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 651, - after_dSTranspose_Q_dim, after_dSTranspose_Q_stride, - true, false); // is virtual + auto After_dSTranspose_Q = + tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 651, after_dSTranspose_Q_dim, + after_dSTranspose_Q_stride, true, false); // is virtual // Create reshape node for V -> V.T - auto reshape_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(dSTensor) - .setyDesc(dSTransposeTensor) - .build(); + auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(dSTensor) + .setyDesc(dSTransposeTensor) + .build(); // Define the matmul desc auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0) - .build(); + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0) + .build(); // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dSTransposeTensor) - .setbMatDesc(qTensor) - .setcMatDesc(After_dSTranspose_Q) - .setmOverrideDesc(mnkOverride) - .setkOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); + auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dSTransposeTensor) + .setbMatDesc(qTensor) + .setcMatDesc(After_dSTranspose_Q) + .setmOverrideDesc(mnkOverride) + .setkOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); ops->push_back(std::move(reshape_op)); ops->push_back(std::move(matmulOp)); @@ -985,1565 +893,1427 @@ static cudnn_frontend::Tensor createdSQBMM( // fused attention FWD FP8 with FE 0.9 void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - bool isTraining, float attnScale, - float dropoutProbability, NVTE_QKV_Layout layout, - void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, - void* devPtrO, - void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, - void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, - void* devPtrAmaxO, void* devPtrAmaxS, - void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, - cudnnDataType_t tensorType, - void* workspace_ptr, - size_t* workspace_size, - 
cudaStream_t stream, - cudnnHandle_t handle_) { + bool isTraining, float attnScale, float dropoutProbability, + NVTE_QKV_Layout layout, void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, + void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, + void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, + void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnnDataType_t tensorType, void* workspace_ptr, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle_) { try { - FADescriptor descriptor{ - b, h, s_q, s_kv, d, - attnScale, isTraining, dropoutProbability, layout, - NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType, false}; - - using CacheType = std::map; - static thread_local CacheType fa_fprop_cache; - - // Get plan from cache if cache is available, otherwise create one - auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { - // If hit, return - auto it = cache.find(descriptor); - if (it != cache.end()) { - auto plan = it->second; - return plan; - } - - // Otherwise, build the op_graph and the plan. Then update cache - std::vector all_ops; - std::vector ops; - - NVTE_CHECK(dropoutProbability == 0.0f || isTraining, - "Dropout probability should be 0.0f for inference mode"); - NVTE_CHECK(dropoutProbability != 1.0f, - "Dropout probability cannot be 1.0"); - - int64_t raggedDim[4] = {b + 1, 1, 1, 1}; - int64_t raggedStride[4] = {1, 1, 1, 1}; - // Create offset tensors - auto QKVOffsetTensor = tensor_create( - CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"], - raggedDim, raggedStride, false, false); - auto ORaggedOffsetTensor = tensor_create( - CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"], - raggedDim, raggedStride, false, false); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - // Create override tensors - auto seqlenMNKTensor = tensor_create( - CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"], - seqlen_dim, seqlen_stride, false, false); - - // Create shared ptrs to ragged offset tensors - // for multiple tensors to use ragged offset - std::shared_ptr QKVRaggedOffsetTensorPtr = - std::make_shared(std::move(QKVOffsetTensor)); - std::shared_ptr ORaggedOffsetTensorPtr = - std::make_shared(std::move(ORaggedOffsetTensor)); - - // Create Q and K tensors that are used in different places - int64_t q_dim[4] = {b, h, s_q, d}; - int64_t q_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); + FADescriptor descriptor{b, + h, + s_q, + s_kv, + d, + attnScale, + isTraining, + dropoutProbability, + layout, + NVTE_Bias_Type::NVTE_NO_BIAS, + NVTE_Mask_Type::NVTE_PADDING_MASK, + tensorType, + false}; + + using CacheType = std::map; + static thread_local CacheType fa_fprop_cache; + + // Get plan from cache if cache is available, otherwise create one + auto get_plan = [&](CacheType& cache, const FADescriptor& descriptor) { + // If hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto plan = it->second; + return plan; + } - int64_t k_dim[4] = {b, h, s_kv, d}; - int64_t k_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - - auto qTensor = tensor_create_with_offset( - tensorType, tensor_name_to_uid["Q"], - q_dim, q_stride, false, false, - QKVRaggedOffsetTensorPtr); - auto kTensor = tensor_create_with_offset( - 
tensorType, tensor_name_to_uid["K"], - k_dim, k_stride, false, false, - QKVRaggedOffsetTensorPtr); - - // Q * K.T - auto afterQKTensor = createQKBMM( - b, h, s_q, s_kv, d, layout, tensorType, - &ops, qTensor, kTensor, - seqlenMNKTensor, QKVRaggedOffsetTensorPtr); - - // QK.T * attn scale - auto AfterAttnScale_before_dequan_Q_tensor = createScale( - afterQKTensor, // input tensor - "AttnScale", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - true, // scale is by value - &ops); - - // QK.T * attn scale * dequant_Q - auto AfterAttnScale_before_dequan_K_tensor = createScale( - AfterAttnScale_before_dequan_Q_tensor, // input tensor - "descaleQ", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // QK.T * attn scale * dequant_Q * dequant_K - auto AfterAttnScale_tensor = createScale( - AfterAttnScale_before_dequan_K_tensor, // input tensor - "descaleK", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - auto BeforeDropoutTensor = createSoftmaxForward( - b, h, s_q, s_kv, &ops, - AfterAttnScale_tensor, isTraining); - - auto AfterDropout_before_quan_S = createDropoutForward( - b, h, s_q, s_kv, dropoutProbability, - &ops, BeforeDropoutTensor); - - // Amax for S - createAmax("amaxS", BeforeDropoutTensor, &ops); - - // After softmax * dropout * scale S -> fp8 input to next bmm with V - auto AfterMultiplyDropout = createScale( - AfterDropout_before_quan_S, // input tensor - "scaleS", // scale tensor - tensorType, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // After softmax * Dropout * V - auto OTensor_before_dequan_S_tensor = createSVBMM( - b, h, s_q, s_kv, d, layout, tensorType, - &ops, AfterMultiplyDropout, - seqlenMNKTensor, QKVRaggedOffsetTensorPtr); - - // O * dequant_S - auto OTensor_before_dequan_V_tensor = createScale( - OTensor_before_dequan_S_tensor, // input tensor - "descaleS", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // O * dequant_S * dequant_V - auto OTensor_before_quan_O_tensor = createScale( - OTensor_before_dequan_V_tensor, // input tensor - "descaleV", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // O * dequant_S * dequant_V * scale O - auto OTensor = createScaleWithOffset( - OTensor_before_quan_O_tensor, // input tensor - "scaleO", // scale tensor - layout, // qkv layout - tensorType, // output tensor type - false, // output not virtual - false, // scale is by value - &ops, - ORaggedOffsetTensorPtr, // ragged offset - "O"); - - // Amax for O - createAmax("amaxO", OTensor_before_quan_O_tensor, &ops); - - for (unsigned int i = 0; i < ops.size(); i++) { - all_ops.push_back(&ops[i]); - } - - // Create an Operation Graph - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(all_ops.size(), all_ops.data()) - .build(); - - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_instant"}, opGraph, - allowAllConfig, filtered_configs, true); - - if (filtered_configs.size() == 0) { - cudnn_frontend::set_error_and_throw_exception( - nullptr, - CUDNN_STATUS_NOT_SUPPORTED, - "run_mha_fprop: No config returned by the heuristics"); - } - - auto plan = 
cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle_) - .setEngineConfig(filtered_configs[0], opGraph.getTag()) - .build(); - cache.insert({descriptor, plan}); - return plan; - }; // end of get_plan - - auto plan = get_plan(fa_fprop_cache, descriptor); - size_t wkspace_size = static_cast(plan.getWorkspaceSize()); - - // Exit to request upper level API to allocate memory if needed - if (workspace_ptr == nullptr) { - *workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t); - return; + // Otherwise, build the op_graph and the plan. Then update cache + std::vector all_ops; + std::vector ops; + + NVTE_CHECK(dropoutProbability == 0.0f || isTraining, + "Dropout probability should be 0.0f for inference mode"); + NVTE_CHECK(dropoutProbability != 1.0f, "Dropout probability cannot be 1.0"); + + int64_t raggedDim[4] = {b + 1, 1, 1, 1}; + int64_t raggedStride[4] = {1, 1, 1, 1}; + // Create offset tensors + auto QKVOffsetTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"], + raggedDim, raggedStride, false, false); + auto ORaggedOffsetTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"], + raggedDim, raggedStride, false, false); + + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + // Create override tensors + auto seqlenMNKTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"], + seqlen_dim, seqlen_stride, false, false); + + // Create shared ptrs to ragged offset tensors + // for multiple tensors to use ragged offset + std::shared_ptr QKVRaggedOffsetTensorPtr = + std::make_shared(std::move(QKVOffsetTensor)); + std::shared_ptr ORaggedOffsetTensorPtr = + std::make_shared(std::move(ORaggedOffsetTensor)); + + // Create Q and K tensors that are used in different places + int64_t q_dim[4] = {b, h, s_q, d}; + int64_t q_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); + + int64_t k_dim[4] = {b, h, s_kv, d}; + int64_t k_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix); + + auto qTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["Q"], q_dim, q_stride, + false, false, QKVRaggedOffsetTensorPtr); + auto kTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["K"], k_dim, k_stride, + false, false, QKVRaggedOffsetTensorPtr); + + // Q * K.T + auto afterQKTensor = createQKBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, qTensor, + kTensor, seqlenMNKTensor, QKVRaggedOffsetTensorPtr); + + // QK.T * attn scale + auto AfterAttnScale_before_dequan_Q_tensor = + createScale(afterQKTensor, // input tensor + "AttnScale", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + true, // scale is by value + &ops); + + // QK.T * attn scale * dequant_Q + auto AfterAttnScale_before_dequan_K_tensor = + createScale(AfterAttnScale_before_dequan_Q_tensor, // input tensor + "descaleQ", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // QK.T * attn scale * dequant_Q * dequant_K + auto AfterAttnScale_tensor = + createScale(AfterAttnScale_before_dequan_K_tensor, // input tensor + "descaleK", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + auto BeforeDropoutTensor = + createSoftmaxForward(b, h, s_q, s_kv, &ops, AfterAttnScale_tensor, isTraining); + + auto AfterDropout_before_quan_S = + 
createDropoutForward(b, h, s_q, s_kv, dropoutProbability, &ops, BeforeDropoutTensor); + + // Amax for S + createAmax("amaxS", BeforeDropoutTensor, &ops); + + // After softmax * dropout * scale S -> fp8 input to next bmm with V + auto AfterMultiplyDropout = createScale(AfterDropout_before_quan_S, // input tensor + "scaleS", // scale tensor + tensorType, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // After softmax * Dropout * V + auto OTensor_before_dequan_S_tensor = + createSVBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, AfterMultiplyDropout, + seqlenMNKTensor, QKVRaggedOffsetTensorPtr); + + // O * dequant_S + auto OTensor_before_dequan_V_tensor = + createScale(OTensor_before_dequan_S_tensor, // input tensor + "descaleS", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // O * dequant_S * dequant_V + auto OTensor_before_quan_O_tensor = + createScale(OTensor_before_dequan_V_tensor, // input tensor + "descaleV", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // O * dequant_S * dequant_V * scale O + auto OTensor = createScaleWithOffset(OTensor_before_quan_O_tensor, // input tensor + "scaleO", // scale tensor + layout, // qkv layout + tensorType, // output tensor type + false, // output not virtual + false, // scale is by value + &ops, + ORaggedOffsetTensorPtr, // ragged offset + "O"); + + // Amax for O + createAmax("amaxO", OTensor_before_quan_O_tensor, &ops); + + for (unsigned int i = 0; i < ops.size(); i++) { + all_ops.push_back(&ops[i]); } - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. 
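Note on the plan cache used above: the forward path looks up its cuDNN execution plan in a thread_local map keyed by the FADescriptor, and the get_plan lambda only builds the operation graph and runs the heuristics on a cache miss. The following is a minimal sketch of that cache-on-miss pattern under simplified assumptions; Descriptor, Plan, and build_plan are illustrative stand-ins for FADescriptor, cudnn_frontend::ExecutionPlan, and the graph/heuristics construction, not the real types.

    #include <cstdint>
    #include <map>
    #include <tuple>

    // Stand-ins for FADescriptor and cudnn_frontend::ExecutionPlan (sketch only).
    struct Descriptor {
        int64_t b, h, s_q, s_kv, d;
        bool operator<(const Descriptor& rhs) const {
            return std::tie(b, h, s_q, s_kv, d) <
                   std::tie(rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d);
        }
    };
    struct Plan {};  // would own the built execution plan

    Plan build_plan(const Descriptor&) { return Plan{}; }  // expensive graph build + heuristics

    Plan get_plan(const Descriptor& desc) {
        // One cache per thread, so no locking; identical problem shapes reuse the same plan.
        static thread_local std::map<Descriptor, Plan> cache;
        auto it = cache.find(desc);
        if (it != cache.end()) return it->second;  // hit: skip op-graph construction
        Plan plan = build_plan(desc);              // miss: build once, then remember it
        cache.emplace(desc, plan);
        return plan;
    }

The real get_plan captures the surrounding state by reference and inserts the cudnn_frontend plan it builds into fa_fprop_cache (or fa_bprop_cache in the backward path) before returning it.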
- NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); - - int32_t* qkv_ragged_offset = reinterpret_cast( - reinterpret_cast(workspace_ptr) + wkspace_size); - int32_t* o_ragged_offset = reinterpret_cast( - reinterpret_cast(workspace_ptr) - + wkspace_size + (b + 1) * sizeof(int32_t)); - int32_t* actual_seqlens_q = reinterpret_cast( - reinterpret_cast(workspace_ptr) - + wkspace_size + (b + 1) * 2 * sizeof(int32_t)); - // FP8 currently only supports self-attention, so doesn't use devPtrcuSeqlensKV - dim3 blockDims(128); - dim3 gridDims((b + blockDims.x)/blockDims.x); - cu_seqlens_to_offsets<<>>( - b, h, d, reinterpret_cast(devPtrcuSeqlensQ), - actual_seqlens_q, qkv_ragged_offset, o_ragged_offset); - void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); - void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); - void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); - - float dropoutScale = 1.0f/(1.0f - dropoutProbability); - - std::set> data_ptrs; - data_ptrs.emplace(std::pair(tensor_name_to_uid["Q"], devPtrQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["K"], devPtrK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["K_TRANSPOSE"], devPtrK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["V"], devPtrV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["AttnScale"], &attnScale)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["O"], devPtrO)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["descaleQ"], devPtrDescaleQ)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["descaleK"], devPtrDescaleK)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["descaleV"], devPtrDescaleV)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["descaleS"], devPtrDescaleS)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["scaleS"], devPtrScaleS)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["scaleO"], devPtrScaleO)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["amaxO"], devPtrAmaxO)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["amaxS"], devPtrAmaxS)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride)); - - // If training, then we need to write out M and Z_INV - if (isTraining) { - data_ptrs.emplace(std::pair( - tensor_name_to_uid["M"], devPtrM)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["Z_INV"], devPtrZInv)); + // Create an Operation Graph + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(all_ops.size(), all_ops.data()) + .build(); + + cudnn_frontend::EngineConfigList filtered_configs; + auto statuses = cudnn_frontend::get_heuristics_list<1>( + {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); + + if (filtered_configs.size() == 0) { + cudnn_frontend::set_error_and_throw_exception( + nullptr, CUDNN_STATUS_NOT_SUPPORTED, + "run_mha_fprop: No config returned by the heuristics"); } - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(data_ptrs) - .build(); + auto plan = cudnn_frontend::ExecutionPlanBuilder() 
+ .setHandle(handle_) + .setEngineConfig(filtered_configs[0], opGraph.getTag()) + .build(); + cache.insert({descriptor, plan}); + return plan; + }; // end of get_plan + + auto plan = get_plan(fa_fprop_cache, descriptor); + size_t wkspace_size = static_cast(plan.getWorkspaceSize()); - NVTE_CHECK_CUDNN(cudnnBackendExecute(handle_, - plan.get_raw_desc(), - variantPack.get_raw_desc())); + // Exit to request upper level API to allocate memory if needed + if (workspace_ptr == nullptr) { + *workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t); + return; + } + + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. + NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); + + int32_t* qkv_ragged_offset = + reinterpret_cast(reinterpret_cast(workspace_ptr) + wkspace_size); + int32_t* o_ragged_offset = reinterpret_cast(reinterpret_cast(workspace_ptr) + + wkspace_size + (b + 1) * sizeof(int32_t)); + int32_t* actual_seqlens_q = reinterpret_cast( + reinterpret_cast(workspace_ptr) + wkspace_size + (b + 1) * 2 * sizeof(int32_t)); + // FP8 currently only supports self-attention, so doesn't use devPtrcuSeqlensKV + dim3 blockDims(128); + dim3 gridDims((b + blockDims.x) / blockDims.x); + cu_seqlens_to_offsets<<>>( + b, h, d, reinterpret_cast(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset, + o_ragged_offset); + void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); + void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); + void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); + + float dropoutScale = 1.0f / (1.0f - dropoutProbability); + + std::set> data_ptrs; + data_ptrs.emplace(std::pair(tensor_name_to_uid["Q"], devPtrQ)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["K"], devPtrK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["K_TRANSPOSE"], devPtrK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["V"], devPtrV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["AttnScale"], &attnScale)); + data_ptrs.emplace( + std::pair(tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale)); + data_ptrs.emplace( + std::pair(tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed)); + data_ptrs.emplace( + std::pair(tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["O"], devPtrO)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleQ"], devPtrDescaleQ)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleK"], devPtrDescaleK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleV"], devPtrDescaleV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleS"], devPtrDescaleS)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["scaleS"], devPtrScaleS)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["scaleO"], devPtrScaleO)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxO"], devPtrAmaxO)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxS"], devPtrAmaxS)); + data_ptrs.emplace( + std::pair(tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset)); + data_ptrs.emplace( + std::pair(tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset)); + data_ptrs.emplace( + std::pair(tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride)); + + // If training, then we need to write out M and Z_INV + if (isTraining) { + data_ptrs.emplace(std::pair(tensor_name_to_uid["M"], devPtrM)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["Z_INV"], devPtrZInv)); + } + + auto variantPack = cudnn_frontend::VariantPackBuilder() + 
.setWorkspacePointer(workspace_ptr) + .setDataPointers(data_ptrs) + .build(); + + NVTE_CHECK_CUDNN(cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc())); } catch (cudnn_frontend::cudnnException& e) { - struct cudaDeviceProp prop; - NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); - - // This example is only for GH100 cards (cudnn Version >= 8900) - if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900)) - && (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH - || e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) { - std::cout << "Example is only supported for GH100 (cuDNN >= 8900) GPUs" << std::endl; - } else { - std::cout << "[ERROR] Exception " << e.what() << std::endl; - } + struct cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); + + // This example is only for GH100 cards (cudnn Version >= 8900) + if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900)) && + (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH || + e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) { + std::cout << "Example is only supported for GH100 (cuDNN >= 8900) GPUs" << std::endl; + } else { + std::cout << "[ERROR] Exception " << e.what() << std::endl; + } } } // fused attention BWD FP8 with FE 0.9 -void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - float attnScale, float dropoutProbability, NVTE_QKV_Layout layout, - void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, - void* devPtrO, void* devPtrdO, - void* devPtrdQ, void* devPtrdK, void* devPtrdV, - void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, - void* devPtrDescaleO, void* devPtrDescaledO, - void* devPtrDescaleS, void* devPtrDescaledS, - void* devPtrScaleS, void* devPtrScaledS, - void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, - void* devPtrAmaxdS, - void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, - void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, - cudnnDataType_t tensorType, - void* workspace_ptr, - size_t* workspace_size, - cudaStream_t stream, - cudnnHandle_t handle_) { +void fused_attn_fp8_bwd_impl( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, float attnScale, + float dropoutProbability, NVTE_QKV_Layout layout, void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, + void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, + void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledS, + void* devPtrScaleS, void* devPtrScaledS, void* devPtrScaledQ, void* devPtrScaledK, + void* devPtrScaledV, void* devPtrAmaxdS, void* devPtrAmaxdQ, void* devPtrAmaxdK, + void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, + void* devPtrDropoutOffset, cudnnDataType_t tensorType, void* workspace_ptr, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle_) { try { - FADescriptor descriptor{ - b, h, s_q, s_kv, d, - attnScale, false, dropoutProbability, layout, - NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType, false}; - - using CacheType = std::map; - static thread_local CacheType fa_bprop_cache; - - // Get plan from cache if cache is available, otherwise create one - auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { - // If hit, return - auto it = cache.find(descriptor); - if 
(it != cache.end()) { - auto plan = it->second; - return plan; - } - - // Otherwise, build the op_graph and the plan. Then update cache - std::vector all_ops; - std::vector ops; - - NVTE_CHECK(dropoutProbability != 1.0f, - "Dropout probability cannot be 1.0"); - - int64_t raggedDim[4] = {b + 1, 1, 1, 1}; - int64_t raggedStride[4] = {1, 1, 1, 1}; - // Create offset tensors - auto QKVOffsetTensor = tensor_create( - CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"], - raggedDim, raggedStride, false, false); - auto ORaggedOffsetTensor = tensor_create( - CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"], - raggedDim, raggedStride, false, false); - - // Create shared ptrs to ragged offset tensors for multiple tensors - std::shared_ptr QKVRaggedOffsetTensorPtr = - std::make_shared(std::move(QKVOffsetTensor)); - std::shared_ptr ORaggedOffsetTensorPtr = - std::make_shared(std::move(ORaggedOffsetTensor)); - - // Create Q and K tensors that are used in different places - int64_t q_dim[4] = {b, h, s_q, d}; - int64_t q_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); + FADescriptor descriptor{b, + h, + s_q, + s_kv, + d, + attnScale, + false, + dropoutProbability, + layout, + NVTE_Bias_Type::NVTE_NO_BIAS, + NVTE_Mask_Type::NVTE_PADDING_MASK, + tensorType, + false}; + + using CacheType = std::map; + static thread_local CacheType fa_bprop_cache; + + // Get plan from cache if cache is available, otherwise create one + auto get_plan = [&](CacheType& cache, const FADescriptor& descriptor) { + // If hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto plan = it->second; + return plan; + } - int64_t k_dim[4] = {b, h, s_kv, d}; - int64_t k_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - - auto qTensor = tensor_create_with_offset( - tensorType, tensor_name_to_uid["Q"], - q_dim, q_stride, false, false, QKVRaggedOffsetTensorPtr); - auto kTensor = tensor_create_with_offset( - tensorType, tensor_name_to_uid["K"], - k_dim, k_stride, false, false, QKVRaggedOffsetTensorPtr); - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - // Create attnScale tensor for multiple ops to use - auto attnScaleTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["AttnScale"], - scale_dim, scale_stride, false, true); // is by value - - // Create descale Q K dO dS global tensors since they are used in multiple places - auto descaleQTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleQ"], - scale_dim, scale_stride, false, false); - auto descaleKTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleK"], - scale_dim, scale_stride, false, false); - auto descaledOTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledO"], - scale_dim, scale_stride, false, false); - auto descaledSTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledS"], - scale_dim, scale_stride, false, false); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - // Create MNK override tensor - auto seqlenMNKTensor = tensor_create( - CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"], - seqlen_dim, seqlen_stride, false, false); - - int64_t O_dim[4] = {b, h, s_q, d}; - int64_t O_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, O_stride, layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); - // Create O and loss tensor - auto OTensor = tensor_create_with_offset( - tensorType, 
tensor_name_to_uid["O"], - O_dim, O_stride, false, false, ORaggedOffsetTensorPtr); - // dO is used in multiple places and E5M2 - auto dOTensor = tensor_create_with_offset( - CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["dO"], - O_dim, O_stride, false, false, ORaggedOffsetTensorPtr); - - // Q * K.T - auto afterQKTensor = createQKBMM( - b, h, s_q, s_kv, d, layout, tensorType, - &ops, qTensor, kTensor, - seqlenMNKTensor, QKVRaggedOffsetTensorPtr); - - // QK.T * attn scale - auto AfterAttnScale_before_dequan_Q_tensor = createScale( - afterQKTensor, // input tensor - attnScaleTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - true, // scale is by value - &ops, - 1999 /*UID offset*/); - - // QK.T * attn scale * dequant_Q - auto AfterAttnScale_before_dequan_K_tensor = createScale( - AfterAttnScale_before_dequan_Q_tensor, // input tensor - descaleQTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, - 2000 /*UID offset*/); - - // QK.T * attn scale * dequant_Q * dequant_K - auto AfterAttnScale_tensor = createScale( - AfterAttnScale_before_dequan_K_tensor, // input tensor - descaleKTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, - 2001 /*UID offset*/); - - auto beforeDropout_QKt_Tensor = createSoftmaxBackward( - b, h, s_q, s_kv, &ops, AfterAttnScale_tensor); - - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - // mask for the dropout. Used in different places - auto dropoutMaskTensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 200, - afterBMM1_dim, afterBMM1_stride, true, false); // is virtual - - auto AfterDropout_before_quan_S = createDropoutBackward( - b, h, s_q, s_kv, dropoutProbability, - &ops, beforeDropout_QKt_Tensor, dropoutMaskTensor); - - // After softmax * scale S -> fp8 input to next bmm with V - auto AfterMultiply = createScale( - AfterDropout_before_quan_S, // input tensor - "scaleS", // scale tensor - tensorType, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // After softmax * dO - auto dVTensor_before_dequan_S = createSdOBMM( - b, h, s_q, s_kv, d, tensorType, - &ops, AfterMultiply, dOTensor, seqlenMNKTensor); - - // O * dequant_S - auto dVTensor_before_dequan_dO = createScale( - dVTensor_before_dequan_S, // input tensor - "descaleS", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // O * dequant_S * dequant_dO - auto dVTensor_before_quan_dV = createScale( - dVTensor_before_dequan_dO, // input tensor - descaledOTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, - 2002 /*UID offset*/); - - // O * dequant_S * dequant_dO * scale dV - auto dVTensor = createScaleWithOffset( - dVTensor_before_quan_dV, // input tensor - "scaledV", // scale tensor - layout, // qkv layout - CUDNN_DATA_FP8_E5M2, // output tensor type - false, // output not virtual - false, // scale is by value - &ops, - QKVRaggedOffsetTensorPtr, // ragged offset - "dV" /*Output tensor name*/); - - // Amax for dV - createAmax("amaxdV", dVTensor_before_quan_dV, &ops); - - auto dS_before_dequan_dO_Tensor = createdOVBMM( - b, h, s_q, s_kv, d, layout, tensorType, - &ops, dOTensor, seqlenMNKTensor, QKVRaggedOffsetTensorPtr); - - 
// dS * dequant_dO - auto dS_before_dequan_V = createScale( - dS_before_dequan_dO_Tensor, // input tensor - descaledOTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, - 2003 /*UID offset*/); - - // O * dequant_S * dequant_dV - auto dS_after_dequan = createScale( - dS_before_dequan_V, // input tensor - "descaleV", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // RNG Multiply - auto beforeDropoutScale_dOVt_Tensor = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 350, - afterBMM1_dim, afterBMM1_stride, true, false); // is virtual - // After dropout mask and scale - auto dS_after_dropout = tensor_create( - CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 351, - afterBMM1_dim, afterBMM1_stride, true, false); // is virtual - - // Define the multiply mask descriptor - auto mulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node - auto maskMul_op = binary_pw_op_create( - dS_after_dequan, dropoutMaskTensor, - beforeDropoutScale_dOVt_Tensor, mulDesc); - - ops.push_back(std::move(maskMul_op)); - - // scale after dropout for dO and O chain - auto dropoutScale_dOVt_OdO_Tensor = tensor_create( - tensorType, tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"], - scale_dim, scale_stride, false, true); // is by value - - // Create a multiply dropout scale Node - auto mul_dropout_scale_op = binary_pw_op_create( - beforeDropoutScale_dOVt_Tensor, - dropoutScale_dOVt_OdO_Tensor, - dS_after_dropout, mulDesc); - - ops.push_back(std::move(mul_dropout_scale_op)); - - // O * dequant_O - auto O_after_dequan_Tensor = createScale(OTensor, // input tensor - "descaleO", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // dO * dequant_dO - auto dO_after_dequan_Tensor = createScale(dOTensor, // input tensor - descaledOTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, - 2004 /*UID offset*/); - - // row reduction sum[(dO * dequant_dO) * (O * dequant_O) * (1 - p)] - auto O_dO_after_rowsum = createdOAndORowReductionChain( - b, h, s_q, s_kv, d, layout, - &ops, O_after_dequan_Tensor, - dO_after_dequan_Tensor, dropoutScale_dOVt_OdO_Tensor); - - // (dS_after_dropout - O_dO_after_rowsum) * AfterDropout_before_quan_S * attnScale - auto S_mul_dS_minus_O_dO = createBiasSubtractionSoftmaxMulChain( - b, h, s_q, s_kv, d, layout, - &ops, dS_after_dropout, - AfterDropout_before_quan_S, O_dO_after_rowsum, - attnScaleTensor); - - - // S_mul_dS_minus_O_dO * scaledS - auto S_mul_dS_minus_O_dO_after_quan_dS = createScale( - S_mul_dS_minus_O_dO, // input tensor - "scaledS", // scale tensor - CUDNN_DATA_FP8_E5M2, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // Amax for dS - createAmax("amaxdS", S_mul_dS_minus_O_dO, &ops); - - // dS @ K - auto After_dS_K = createdSKBMM( - b, h, s_q, s_kv, d, &ops, - S_mul_dS_minus_O_dO_after_quan_dS, - kTensor, seqlenMNKTensor); - - // (dS * K) * descale dS - auto After_dS_K_before_dequan_K = createScale( - After_dS_K, // input tensor - descaledSTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, - 2006 /*UID offset*/); - - // (dS * K) * descale dS * descale K - auto After_dS_K_before_quan_dQ = 
createScale( - After_dS_K_before_dequan_K, // input tensor - descaleKTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, - 2007 /*UID offset*/); - - // (dS * K) * descale dS * descale K * scale dQ - auto dQ = createScaleWithOffset( - After_dS_K_before_quan_dQ, // input tensor - "scaledQ", // scale tensor - layout, // qkv layout - CUDNN_DATA_FP8_E5M2, // output tensor type - false, // output not virtual - false, // scale is by value - &ops, - QKVRaggedOffsetTensorPtr, // ragged offset - "dQ"); - - // Amax for dQ - createAmax("amaxdQ", After_dS_K_before_quan_dQ, &ops); - - // dS.T @ Q - auto After_dSTranspose_Q = createdSQBMM( - b, h, s_q, s_kv, d, layout, &ops, - S_mul_dS_minus_O_dO_after_quan_dS, - qTensor, seqlenMNKTensor); - - // (dS.T * Q) * descale dS - auto After_dSTranspose_Q_before_dequan_Q = createScale( - After_dSTranspose_Q, // input tensor - descaledSTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, - 2009 /*UID offset*/); - - // (dS.T * Q) * descale dS * descale Q - auto After_dSTranspose_Q_before_quan_dK = createScale( - After_dSTranspose_Q_before_dequan_Q, // input tensor - descaleQTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, - 2010 /*UID offset*/); - - // (dS.T * Q) * descale dS * descale Q * scale dK - auto dK = createScaleWithOffset( - After_dSTranspose_Q_before_quan_dK, // input tensor - "scaledK", // scale tensor - layout, // qkv layout - CUDNN_DATA_FP8_E5M2, // output tensor type - false, // output not virtual - false, // scale is by value - &ops, - QKVRaggedOffsetTensorPtr, // ragged offset - "dK"); - - // Amax for dK - createAmax("amaxdK", After_dSTranspose_Q_before_quan_dK, &ops); - - for (unsigned int i = 0; i < ops.size(); i++) { - all_ops.push_back(&ops[i]); - } - - // Create an Operation Graph - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(all_ops.size(), all_ops.data()) - .build(); - - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_instant"}, opGraph, - allowAllConfig, filtered_configs, true); - - if (filtered_configs.size() == 0) { - cudnn_frontend::set_error_and_throw_exception( - nullptr, - CUDNN_STATUS_NOT_SUPPORTED, - "run_mha_bprop: No config returned by the heuristics"); - } - - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle_) - .setEngineConfig(filtered_configs[0], opGraph.getTag()) - .build(); - cache.insert({descriptor, plan}); - return plan; - }; - - auto plan = get_plan(fa_bprop_cache, descriptor); - size_t wkspace_size = static_cast(plan.getWorkspaceSize()); - - // Exit to request upper level API to allocate memory if needed - if (workspace_ptr == nullptr) { - *workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t); - return; + // Otherwise, build the op_graph and the plan. 
Then update cache + std::vector all_ops; + std::vector ops; + + NVTE_CHECK(dropoutProbability != 1.0f, "Dropout probability cannot be 1.0"); + + int64_t raggedDim[4] = {b + 1, 1, 1, 1}; + int64_t raggedStride[4] = {1, 1, 1, 1}; + // Create offset tensors + auto QKVOffsetTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"], + raggedDim, raggedStride, false, false); + auto ORaggedOffsetTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"], + raggedDim, raggedStride, false, false); + + // Create shared ptrs to ragged offset tensors for multiple tensors + std::shared_ptr QKVRaggedOffsetTensorPtr = + std::make_shared(std::move(QKVOffsetTensor)); + std::shared_ptr ORaggedOffsetTensorPtr = + std::make_shared(std::move(ORaggedOffsetTensor)); + + // Create Q and K tensors that are used in different places + int64_t q_dim[4] = {b, h, s_q, d}; + int64_t q_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); + + int64_t k_dim[4] = {b, h, s_kv, d}; + int64_t k_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix); + + auto qTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["Q"], q_dim, q_stride, + false, false, QKVRaggedOffsetTensorPtr); + auto kTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["K"], k_dim, k_stride, + false, false, QKVRaggedOffsetTensorPtr); + + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + // Create attnScale tensor for multiple ops to use + auto attnScaleTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["AttnScale"], + scale_dim, scale_stride, false, true); // is by value + + // Create descale Q K dO dS global tensors since they are used in multiple places + auto descaleQTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleQ"], + scale_dim, scale_stride, false, false); + auto descaleKTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleK"], + scale_dim, scale_stride, false, false); + auto descaledOTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledO"], + scale_dim, scale_stride, false, false); + auto descaledSTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledS"], + scale_dim, scale_stride, false, false); + + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + // Create MNK override tensor + auto seqlenMNKTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"], + seqlen_dim, seqlen_stride, false, false); + + int64_t O_dim[4] = {b, h, s_q, d}; + int64_t O_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, O_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + // Create O and loss tensor + auto OTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["O"], O_dim, O_stride, + false, false, ORaggedOffsetTensorPtr); + // dO is used in multiple places and E5M2 + auto dOTensor = + tensor_create_with_offset(CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["dO"], O_dim, O_stride, + false, false, ORaggedOffsetTensorPtr); + + // Q * K.T + auto afterQKTensor = createQKBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, qTensor, + kTensor, seqlenMNKTensor, QKVRaggedOffsetTensorPtr); + + // QK.T * attn scale + auto AfterAttnScale_before_dequan_Q_tensor = + createScale(afterQKTensor, // input tensor + attnScaleTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + true, // scale is by value + &ops, 1999 /*UID offset*/); + 
+
+    // QK.T * attn scale * dequant_Q
+    auto AfterAttnScale_before_dequan_K_tensor =
+        createScale(AfterAttnScale_before_dequan_Q_tensor,  // input tensor
+                    descaleQTensor,    // scale tensor
+                    CUDNN_DATA_FLOAT,  // output tensor type
+                    true,              // output is virtual
+                    false,             // scale is by value
+                    &ops, 2000 /*UID offset*/);
+
+    // QK.T * attn scale * dequant_Q * dequant_K
+    auto AfterAttnScale_tensor =
+        createScale(AfterAttnScale_before_dequan_K_tensor,  // input tensor
+                    descaleKTensor,    // scale tensor
+                    CUDNN_DATA_FLOAT,  // output tensor type
+                    true,              // output is virtual
+                    false,             // scale is by value
+                    &ops, 2001 /*UID offset*/);
+
+    auto beforeDropout_QKt_Tensor =
+        createSoftmaxBackward(b, h, s_q, s_kv, &ops, AfterAttnScale_tensor);
+
+    int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
+    int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
+
+    // mask for the dropout. Used in different places
+    auto dropoutMaskTensor =
+        tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 200, afterBMM1_dim,
+                      afterBMM1_stride, true, false);  // is virtual
+
+    auto AfterDropout_before_quan_S = createDropoutBackward(
+        b, h, s_q, s_kv, dropoutProbability, &ops, beforeDropout_QKt_Tensor, dropoutMaskTensor);
+
+    // After softmax * scale S -> fp8 input to next bmm with V
+    auto AfterMultiply = createScale(AfterDropout_before_quan_S,  // input tensor
+                                     "scaleS",          // scale tensor
+                                     tensorType,        // output tensor type
+                                     true,              // output is virtual
+                                     false,             // scale is by value
+                                     &ops);
+
+    // After softmax * dO
+    auto dVTensor_before_dequan_S = createSdOBMM(b, h, s_q, s_kv, d, tensorType, &ops,
+                                                 AfterMultiply, dOTensor, seqlenMNKTensor);
+
+    // O * dequant_S
+    auto dVTensor_before_dequan_dO = createScale(dVTensor_before_dequan_S,  // input tensor
+                                                 "descaleS",        // scale tensor
+                                                 CUDNN_DATA_FLOAT,  // output tensor type
+                                                 true,              // output is virtual
+                                                 false,             // scale is by value
+                                                 &ops);
+
+    // O * dequant_S * dequant_dO
+    auto dVTensor_before_quan_dV = createScale(dVTensor_before_dequan_dO,  // input tensor
+                                               descaledOTensor,   // scale tensor
+                                               CUDNN_DATA_FLOAT,  // output tensor type
+                                               true,              // output is virtual
+                                               false,             // scale is by value
+                                               &ops, 2002 /*UID offset*/);
+
+    // O * dequant_S * dequant_dO * scale dV
+    auto dVTensor = createScaleWithOffset(dVTensor_before_quan_dV,  // input tensor
+                                          "scaledV",            // scale tensor
+                                          layout,               // qkv layout
+                                          CUDNN_DATA_FP8_E5M2,  // output tensor type
+                                          false,                // output not virtual
+                                          false,                // scale is by value
+                                          &ops,
+                                          QKVRaggedOffsetTensorPtr,  // ragged offset
+                                          "dV" /*Output tensor name*/);
+
+    // Amax for dV
+    createAmax("amaxdV", dVTensor_before_quan_dV, &ops);
+
+    auto dS_before_dequan_dO_Tensor =
+        createdOVBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, dOTensor, seqlenMNKTensor,
+                     QKVRaggedOffsetTensorPtr);
+
+    // dS * dequant_dO
+    auto dS_before_dequan_V = createScale(dS_before_dequan_dO_Tensor,  // input tensor
+                                          descaledOTensor,   // scale tensor
+                                          CUDNN_DATA_FLOAT,  // output tensor type
+                                          true,              // output is virtual
+                                          false,             // scale is by value
+                                          &ops, 2003 /*UID offset*/);
+
+    // O * dequant_S * dequant_dV
+    auto dS_after_dequan = createScale(dS_before_dequan_V,  // input tensor
+                                       "descaleV",        // scale tensor
+                                       CUDNN_DATA_FLOAT,  // output tensor type
+                                       true,              // output is virtual
+                                       false,             // scale is by value
+                                       &ops);
+
+    // RNG Multiply
+    auto beforeDropoutScale_dOVt_Tensor =
+        tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 350, afterBMM1_dim,
+                      afterBMM1_stride, true, false);  // is virtual
+    // After dropout mask and scale
+    auto dS_after_dropout =
+        tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 351, afterBMM1_dim,
+                      afterBMM1_stride, true, false);  // is virtual
+
+    // Define the multiply mask descriptor
+    auto mulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
+
+    // Create a multiply mask Node
+    auto maskMul_op = binary_pw_op_create(dS_after_dequan, dropoutMaskTensor,
+                                          beforeDropoutScale_dOVt_Tensor, mulDesc);
+
+    ops.push_back(std::move(maskMul_op));
+
+    // scale after dropout for dO and O chain
+    auto dropoutScale_dOVt_OdO_Tensor =
+        tensor_create(tensorType, tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"], scale_dim,
+                      scale_stride, false, true);  // is by value
+
+    // Create a multiply dropout scale Node
+    auto mul_dropout_scale_op = binary_pw_op_create(
+        beforeDropoutScale_dOVt_Tensor, dropoutScale_dOVt_OdO_Tensor, dS_after_dropout, mulDesc);
+
+    ops.push_back(std::move(mul_dropout_scale_op));
+
+    // O * dequant_O
+    auto O_after_dequan_Tensor = createScale(OTensor,  // input tensor
+                                             "descaleO",        // scale tensor
+                                             CUDNN_DATA_FLOAT,  // output tensor type
+                                             true,              // output is virtual
+                                             false,             // scale is by value
+                                             &ops);
+
+    // dO * dequant_dO
+    auto dO_after_dequan_Tensor = createScale(dOTensor,  // input tensor
+                                              descaledOTensor,   // scale tensor
+                                              CUDNN_DATA_FLOAT,  // output tensor type
+                                              true,              // output is virtual
+                                              false,             // scale is by value
+                                              &ops, 2004 /*UID offset*/);
+
+    // row reduction sum[(dO * dequant_dO) * (O * dequant_O) * (1 - p)]
+    auto O_dO_after_rowsum =
+        createdOAndORowReductionChain(b, h, s_q, s_kv, d, layout, &ops, O_after_dequan_Tensor,
+                                      dO_after_dequan_Tensor, dropoutScale_dOVt_OdO_Tensor);
+
+    // (dS_after_dropout - O_dO_after_rowsum) * AfterDropout_before_quan_S * attnScale
+    auto S_mul_dS_minus_O_dO = createBiasSubtractionSoftmaxMulChain(
+        b, h, s_q, s_kv, d, layout, &ops, dS_after_dropout, AfterDropout_before_quan_S,
+        O_dO_after_rowsum, attnScaleTensor);
+
+    // S_mul_dS_minus_O_dO * scaledS
+    auto S_mul_dS_minus_O_dO_after_quan_dS =
+        createScale(S_mul_dS_minus_O_dO,  // input tensor
+                    "scaledS",            // scale tensor
+                    CUDNN_DATA_FP8_E5M2,  // output tensor type
+                    true,                 // output is virtual
+                    false,                // scale is by value
+                    &ops);
+
+    // Amax for dS
+    createAmax("amaxdS", S_mul_dS_minus_O_dO, &ops);
+
+    // dS @ K
+    auto After_dS_K = createdSKBMM(b, h, s_q, s_kv, d, &ops, S_mul_dS_minus_O_dO_after_quan_dS,
+                                   kTensor, seqlenMNKTensor);
+
+    // (dS * K) * descale dS
+    auto After_dS_K_before_dequan_K = createScale(After_dS_K,  // input tensor
+                                                  descaledSTensor,   // scale tensor
+                                                  CUDNN_DATA_FLOAT,  // output tensor type
+                                                  true,              // output is virtual
+                                                  false,             // scale is by value
+                                                  &ops, 2006 /*UID offset*/);
+
+    // (dS * K) * descale dS * descale K
+    auto After_dS_K_before_quan_dQ = createScale(After_dS_K_before_dequan_K,  // input tensor
+                                                 descaleKTensor,    // scale tensor
+                                                 CUDNN_DATA_FLOAT,  // output tensor type
+                                                 true,              // output is virtual
+                                                 false,             // scale is by value
+                                                 &ops, 2007 /*UID offset*/);
+
+    // (dS * K) * descale dS * descale K * scale dQ
+    auto dQ = createScaleWithOffset(After_dS_K_before_quan_dQ,  // input tensor
+                                    "scaledQ",            // scale tensor
+                                    layout,               // qkv layout
+                                    CUDNN_DATA_FP8_E5M2,  // output tensor type
+                                    false,                // output not virtual
+                                    false,                // scale is by value
+                                    &ops,
+                                    QKVRaggedOffsetTensorPtr,  // ragged offset
+                                    "dQ");
+
+    // Amax for dQ
+    createAmax("amaxdQ", After_dS_K_before_quan_dQ, &ops);
+
+    // dS.T @ Q
+    auto After_dSTranspose_Q =
+        createdSQBMM(b, h, s_q, s_kv, d, layout, &ops, S_mul_dS_minus_O_dO_after_quan_dS, qTensor,
+
seqlenMNKTensor); + + // (dS.T * Q) * descale dS + auto After_dSTranspose_Q_before_dequan_Q = + createScale(After_dSTranspose_Q, // input tensor + descaledSTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, 2009 /*UID offset*/); + + // (dS.T * Q) * descale dS * descale Q + auto After_dSTranspose_Q_before_quan_dK = + createScale(After_dSTranspose_Q_before_dequan_Q, // input tensor + descaleQTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, 2010 /*UID offset*/); + + // (dS.T * Q) * descale dS * descale Q * scale dK + auto dK = createScaleWithOffset(After_dSTranspose_Q_before_quan_dK, // input tensor + "scaledK", // scale tensor + layout, // qkv layout + CUDNN_DATA_FP8_E5M2, // output tensor type + false, // output not virtual + false, // scale is by value + &ops, + QKVRaggedOffsetTensorPtr, // ragged offset + "dK"); + + // Amax for dK + createAmax("amaxdK", After_dSTranspose_Q_before_quan_dK, &ops); + + for (unsigned int i = 0; i < ops.size(); i++) { + all_ops.push_back(&ops[i]); } - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. - NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); - - int32_t* qkv_ragged_offset = reinterpret_cast( - reinterpret_cast(workspace_ptr) + wkspace_size); - int32_t* o_ragged_offset = reinterpret_cast( - reinterpret_cast(workspace_ptr) - + wkspace_size + (b + 1) * sizeof(int32_t)); - int32_t* actual_seqlens_q = reinterpret_cast( - reinterpret_cast(workspace_ptr) - + wkspace_size + (b + 1) * 2 * sizeof(int32_t)); - // FP8 currently only supports self-attention, so doesn't use devPtrcuSeqlensKV - dim3 blockDims(128); - dim3 gridDims((b + blockDims.x)/blockDims.x); - cu_seqlens_to_offsets<<>>( - b, h, d, reinterpret_cast(devPtrcuSeqlensQ), - actual_seqlens_q, qkv_ragged_offset, o_ragged_offset); - void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); - void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); - void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); - - std::set> data_ptrs; - float dropoutScale = 1.0f/(1.0f - dropoutProbability); - float dropoutScale_dOVt_OdO = 1.0f - dropoutProbability; - data_ptrs.emplace(std::pair(tensor_name_to_uid["Q"], devPtrQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["K"], devPtrK)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["K_TRANSPOSE"], devPtrK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["V"], devPtrV)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["V_TRANSPOSE"], devPtrV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["dQ"], devPtrdQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["dK"], devPtrdK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["dV"], devPtrdV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["dO"], devPtrdO)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["AttnScale"], &attnScale)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"], - &dropoutScale_dOVt_OdO)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["M"], devPtrM)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["Z_INV"], 
devPtrZInv)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["O"], devPtrO)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["descaleQ"], devPtrDescaleQ)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["descaleK"], devPtrDescaleK)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["descaleV"], devPtrDescaleV)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["descaleS"], devPtrDescaleS)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["descaledS"], devPtrDescaledS)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["descaleO"], devPtrDescaleO)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["descaledO"], devPtrDescaledO)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["scaleS"], devPtrScaleS)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["scaledS"], devPtrScaledS)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["scaledQ"], devPtrScaledQ)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["scaledK"], devPtrScaledK)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["scaledV"], devPtrScaledV)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["amaxdS"], devPtrAmaxdS)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["amaxdQ"], devPtrAmaxdQ)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["amaxdK"], devPtrAmaxdK)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["amaxdV"], devPtrAmaxdV)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset)); - data_ptrs.emplace(std::pair( - tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride)); - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(data_ptrs) - .build(); - NVTE_CHECK_CUDNN(cudnnBackendExecute(handle_, - plan.get_raw_desc(), - variantPack.get_raw_desc())); - } catch (cudnn_frontend::cudnnException& e) { - struct cudaDeviceProp prop; - NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); - - // This example is only for GH100 cards (cudnn Version >= 8900) - if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900)) - && (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH - || e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) { - std::cout << "Example is only supported for GH100 (cuDNN >= 8900) GPUs" << std::endl; - } else { - std::cout << "[ERROR] Exception " << e.what() << std::endl; + // Create an Operation Graph + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(all_ops.size(), all_ops.data()) + .build(); + + cudnn_frontend::EngineConfigList filtered_configs; + auto statuses = cudnn_frontend::get_heuristics_list<1>( + {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); + + if (filtered_configs.size() == 0) { + cudnn_frontend::set_error_and_throw_exception( + nullptr, CUDNN_STATUS_NOT_SUPPORTED, + "run_mha_bprop: No config returned by the heuristics"); } + + auto plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle_) + .setEngineConfig(filtered_configs[0], opGraph.getTag()) + .build(); + cache.insert({descriptor, plan}); + return plan; + }; + + auto plan = get_plan(fa_bprop_cache, descriptor); + size_t wkspace_size = static_cast(plan.getWorkspaceSize()); + + // Exit to request upper level API to allocate memory if needed + if (workspace_ptr == nullptr) { + *workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t); + return; + } + + // cuDNN stream check needs to be moved here to 
support dummy kernel calls with + // null streams for sizing the cuDNN workspace. + NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); + + int32_t* qkv_ragged_offset = + reinterpret_cast(reinterpret_cast(workspace_ptr) + wkspace_size); + int32_t* o_ragged_offset = reinterpret_cast(reinterpret_cast(workspace_ptr) + + wkspace_size + (b + 1) * sizeof(int32_t)); + int32_t* actual_seqlens_q = reinterpret_cast( + reinterpret_cast(workspace_ptr) + wkspace_size + (b + 1) * 2 * sizeof(int32_t)); + // FP8 currently only supports self-attention, so doesn't use devPtrcuSeqlensKV + dim3 blockDims(128); + dim3 gridDims((b + blockDims.x) / blockDims.x); + cu_seqlens_to_offsets<<>>( + b, h, d, reinterpret_cast(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset, + o_ragged_offset); + void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); + void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); + void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); + + std::set> data_ptrs; + float dropoutScale = 1.0f / (1.0f - dropoutProbability); + float dropoutScale_dOVt_OdO = 1.0f - dropoutProbability; + data_ptrs.emplace(std::pair(tensor_name_to_uid["Q"], devPtrQ)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["K"], devPtrK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["K_TRANSPOSE"], devPtrK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["V"], devPtrV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["V_TRANSPOSE"], devPtrV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["dQ"], devPtrdQ)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["dK"], devPtrdK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["dV"], devPtrdV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["dO"], devPtrdO)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["AttnScale"], &attnScale)); + data_ptrs.emplace( + std::pair(tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"], + &dropoutScale_dOVt_OdO)); + data_ptrs.emplace( + std::pair(tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed)); + data_ptrs.emplace( + std::pair(tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["M"], devPtrM)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["Z_INV"], devPtrZInv)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["O"], devPtrO)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleQ"], devPtrDescaleQ)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleK"], devPtrDescaleK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleV"], devPtrDescaleV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleS"], devPtrDescaleS)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["descaledS"], devPtrDescaledS)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleO"], devPtrDescaleO)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["descaledO"], devPtrDescaledO)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["scaleS"], devPtrScaleS)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["scaledS"], devPtrScaledS)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["scaledQ"], devPtrScaledQ)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["scaledK"], devPtrScaledK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["scaledV"], devPtrScaledV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxdS"], devPtrAmaxdS)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxdQ"], devPtrAmaxdQ)); + 
data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxdK"], devPtrAmaxdK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxdV"], devPtrAmaxdV)); + data_ptrs.emplace( + std::pair(tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset)); + data_ptrs.emplace( + std::pair(tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset)); + data_ptrs.emplace( + std::pair(tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride)); + + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(data_ptrs) + .build(); + NVTE_CHECK_CUDNN(cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc())); + } catch (cudnn_frontend::cudnnException& e) { + struct cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); + + // This example is only for GH100 cards (cudnn Version >= 8900) + if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900)) && + (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH || + e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) { + std::cout << "Example is only supported for GH100 (cuDNN >= 8900) GPUs" << std::endl; + } else { + std::cout << "[ERROR] Exception " << e.what() << std::endl; + } } } // fused attention FWD FP8 with FE 1.0+ -void fused_attn_fp8_fwd_impl_v1(int64_t b, int64_t h, int64_t hg, - int64_t s_q, int64_t s_kv, int64_t d, - bool is_training, float scaling_factor, - float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, - void* devPtrO, - void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, - void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, - void* devPtrAmaxO, void* devPtrAmaxS, - void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, - cudnn_frontend::DataType_t fwd_tensor_type, - void* workspace, - size_t* workspace_size, - cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); - bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) - || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) - || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); - bool is_dropout = (is_training && dropout_probability != 0.0f); - auto bias_b = b; - auto bias_h = h; - NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); - NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - NVTE_CHECK(~is_padding, - "FP8 fused attention does not support padding/padding_causal mask yet!"); - NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!"); - - try { - FADescriptor_v1 descriptor{b, h, - hg, s_q, - s_kv, d, - bias_b, bias_h, - scaling_factor, is_training, - dropout_probability, layout, - bias_type, mask_type, - fwd_tensor_type, fwd_tensor_type}; - - namespace fe = cudnn_frontend; - using graph_and_tensors = std::tuple, - std::shared_ptr, // Q - std::shared_ptr, // K - std::shared_ptr, // V - std::shared_ptr, // descale_q - std::shared_ptr, // descale_k - std::shared_ptr, // descale_v - std::shared_ptr, // descale_s - std::shared_ptr, // scale_s - std::shared_ptr, // scale_o - std::shared_ptr, // attn_scale - std::shared_ptr, // O - std::shared_ptr, // amax_s - 
std::shared_ptr, // amax_o - std::shared_ptr, // Stats - std::shared_ptr, // bias - std::shared_ptr, // seq_q - std::shared_ptr, // seq_kv - std::shared_ptr, // dropout_seed - std::shared_ptr >; // dropout_offset - - using CacheType = std::map; - static thread_local CacheType sdpa_fp8_fprop_cache; - - // Get plan from cache if cache is available, otherwise create one - auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor) - -> graph_and_tensors { - // if hit, return - auto it = cache.find(descriptor); - if (it != cache.end()) { - auto graph = it->second; - return graph; - } - - // otherwise, build the op_graph and the plan. Then update cache - auto mha_graph = std::make_shared(); - mha_graph->set_io_data_type(fwd_tensor_type) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - std::shared_ptr Q, K, V, attn_scale; - std::shared_ptr descale_q, descale_k, descale_v; - std::shared_ptr descale_s, scale_s, scale_o; - std::shared_ptr bias, seq_q, seq_kv; - std::shared_ptr dropout_seed, dropout_offset; - - std::vector q_stride(4); - std::vector k_stride(4); - std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); - K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); - V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - - attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); - - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); - - fe::graph::SDPA_fp8_attributes sdpa_options; - sdpa_options = fe::graph::SDPA_fp8_attributes() - .set_name("sdpa_fp8") - .set_is_inference(false) - .set_causal_mask(is_causal) - .set_attn_scale(attn_scale); - - // sdpa_options.set_alibi_mask(is_alibi); - // if (is_bias) { - // bias = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("bias") - // .set_dim({bias_b, bias_h, s_q, s_kv}) - // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); - // sdpa_options.set_bias(bias); - // } - - // if (is_padding) { - // seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_q") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_kv") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // sdpa_options.set_padding_mask(is_padding) - // .set_seq_len_q(seq_q) - 
// .set_seq_len_kv(seq_kv); - // } - - // if (is_dropout) { - // dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Seed") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Offset") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // sdpa_options.set_dropout( - // dropout_probability, dropout_seed, dropout_offset); - // } - - auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( - Q, K, V, descale_q, descale_k, descale_v, descale_s, - scale_s, scale_o, sdpa_options); - - std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); - amax_o->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_s->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - - Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}); - - std::tuple, // Q - std::shared_ptr, // K - std::shared_ptr, // V - std::shared_ptr, // descale_q - std::shared_ptr, // descale_k - std::shared_ptr, // descale_v - std::shared_ptr, // descale_s - std::shared_ptr, // scale_s - std::shared_ptr, // scale_o - std::shared_ptr, // attn_scale - std::shared_ptr, // O - std::shared_ptr, // amax_s - std::shared_ptr > // amax_o - key_tensors_tuple = std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, - descale_s, scale_s, scale_o, attn_scale, O, amax_s, amax_o); - auto Stats_tuple = std::make_tuple(Stats); - auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); - auto padding_tuple = is_padding ? - std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto dropout_tuple = is_dropout ? - std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); - - NVTE_CHECK_CUDNN_FE(mha_graph->validate()); - NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); - NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); - NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); - NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - - auto return_tuple = std::tuple_cat( - std::make_tuple(mha_graph), key_tensors_tuple, - Stats_tuple, bias_tuple, padding_tuple, dropout_tuple); - cache.insert({descriptor, return_tuple}); - - return return_tuple; - }; - - auto [mha_graph, Q, K, V, descale_q, descale_k, descale_v, descale_s, - scale_s, scale_o, attn_scale, O, amax_s, amax_o, Stats, - bias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph( - sdpa_fp8_fprop_cache, descriptor); - - auto plan_workspace_size = mha_graph->get_workspace_size(); - - // Exit to request upper level API to allocate memory if needed - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); - if (workspace == nullptr) { - *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; - return; - } - - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. 
- NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - - // Build variant pack - std::unordered_map, void*> variant_pack = { - {Q, devPtrQ}, - {K, devPtrK}, - {V, devPtrV}, - {descale_q, devPtrDescaleQ}, - {descale_k, devPtrDescaleK}, - {descale_v, devPtrDescaleV}, - {descale_s, devPtrDescaleS}, - {scale_s, devPtrScaleS}, - {scale_o, devPtrScaleO}, - {attn_scale, &scaling_factor}, - {O, devPtrO}, - {amax_s, devPtrAmaxS}, - {amax_o, devPtrAmaxO}, - {Stats, devPtrM}}; - - // if (is_bias) { - // variant_pack[bias] = devPtrBias; - // } - - // if (is_padding) { - // constexpr size_t nthreads_per_block = 128; - // const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - // void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - // void *devActualSeqlenKV = static_cast(devActualSeqlenQ) - // + b * sizeof(int32_t); - // cu_seqlens_to_actual_seqlens<<>>( - // b, static_cast(devPtrCuSeqlensQ), - // static_cast(devPtrCuSeqlensKV), - // static_cast(devActualSeqlenQ), - // static_cast(devActualSeqlenKV)); - // variant_pack[seq_q] = devActualSeqlenQ; - // variant_pack[seq_kv] = devActualSeqlenKV; - // } - - // if (is_dropout) { - // variant_pack[dropout_seed] = devPtrDropoutSeed; - // variant_pack[dropout_offset] = devPtrDropoutOffset; - // } - NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); - } catch (cudnn_frontend::cudnnException &e) { - NVTE_ERROR(e.what()); +void fused_attn_fp8_fwd_impl_v1( + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, + void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, + void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type, + void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; + bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); + bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + bool is_dropout = (is_training && dropout_probability != 0.0f); + auto bias_b = b; + auto bias_h = h; + NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); + NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); + NVTE_CHECK(~is_padding, "FP8 fused attention does not support padding/padding_causal mask yet!"); + NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!"); + + try { + FADescriptor_v1 descriptor{b, + h, + hg, + s_q, + s_kv, + d, + bias_b, + bias_h, + scaling_factor, + is_training, + dropout_probability, + layout, + bias_type, + mask_type, + fwd_tensor_type, + fwd_tensor_type}; + + namespace fe = cudnn_frontend; + using graph_and_tensors = + std::tuple, + std::shared_ptr, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // descale_q + std::shared_ptr, // descale_k + std::shared_ptr, // descale_v + std::shared_ptr, 
// descale_s + std::shared_ptr, // scale_s + std::shared_ptr, // scale_o + std::shared_ptr, // attn_scale + std::shared_ptr, // O + std::shared_ptr, // amax_s + std::shared_ptr, // amax_o + std::shared_ptr, // Stats + std::shared_ptr, // bias + std::shared_ptr, // seq_q + std::shared_ptr, // seq_kv + std::shared_ptr, // dropout_seed + std::shared_ptr>; // dropout_offset + + using CacheType = std::map; + static thread_local CacheType sdpa_fp8_fprop_cache; + + // Get plan from cache if cache is available, otherwise create one + auto get_graph = [&](CacheType& cache, const FADescriptor_v1& descriptor) -> graph_and_tensors { + // if hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto graph = it->second; + return graph; + } + + // otherwise, build the op_graph and the plan. Then update cache + auto mha_graph = std::make_shared(); + mha_graph->set_io_data_type(fwd_tensor_type) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + std::shared_ptr Q, K, V, attn_scale; + std::shared_ptr descale_q, descale_k, descale_v; + std::shared_ptr descale_s, scale_s, scale_o; + std::shared_ptr bias, seq_q, seq_kv; + std::shared_ptr dropout_seed, dropout_offset; + + std::vector q_stride(4); + std::vector k_stride(4); + std::vector v_stride(4); + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride(q_stride)); + K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride)); + V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride)); + + attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); + scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); + + fe::graph::SDPA_fp8_attributes sdpa_options; + sdpa_options = fe::graph::SDPA_fp8_attributes() + .set_name("sdpa_fp8") + .set_is_inference(false) + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale); + + // sdpa_options.set_alibi_mask(is_alibi); + // if (is_bias) { + // bias = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("bias") + // .set_dim({bias_b, bias_h, s_q, s_kv}) + // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // sdpa_options.set_bias(bias); + // } + + // if (is_padding) { + // seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("seq_q") + // .set_dim({b, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + // seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("seq_kv") + // .set_dim({b, 1, 
1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + // sdpa_options.set_padding_mask(is_padding) + // .set_seq_len_q(seq_q) + // .set_seq_len_kv(seq_kv); + // } + + // if (is_dropout) { + // dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("Seed") + // .set_dim({1, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT64)); + // dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("Offset") + // .set_dim({1, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT64)); + // sdpa_options.set_dropout( + // dropout_probability, dropout_seed, dropout_offset); + // } + + auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( + Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); + + std::vector o_stride(4); + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_O_Matrix); + O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); + amax_o->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_s->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + + Stats->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}); + + std::tuple, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // descale_q + std::shared_ptr, // descale_k + std::shared_ptr, // descale_v + std::shared_ptr, // descale_s + std::shared_ptr, // scale_s + std::shared_ptr, // scale_o + std::shared_ptr, // attn_scale + std::shared_ptr, // O + std::shared_ptr, // amax_s + std::shared_ptr> // amax_o + key_tensors_tuple = std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, attn_scale, O, amax_s, amax_o); + auto Stats_tuple = std::make_tuple(Stats); + auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); + auto padding_tuple = + is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); + auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) + : std::make_tuple(nullptr, nullptr); + + NVTE_CHECK_CUDNN_FE(mha_graph->validate()); + NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); + NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); + NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); + NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); + + auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, + bias_tuple, padding_tuple, dropout_tuple); + cache.insert({descriptor, return_tuple}); + + return return_tuple; + }; + + auto [mha_graph, Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, + attn_scale, O, amax_s, amax_o, Stats, bias, seq_q, seq_kv, dropout_seed, dropout_offset] = + get_graph(sdpa_fp8_fprop_cache, descriptor); + + auto plan_workspace_size = mha_graph->get_workspace_size(); + + // Exit to request upper level API to allocate memory if needed + size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); + if (workspace == nullptr) { + *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; + return; } + + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. 
+ NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + + // Build variant pack + std::unordered_map, void*> variant_pack = { + {Q, devPtrQ}, + {K, devPtrK}, + {V, devPtrV}, + {descale_q, devPtrDescaleQ}, + {descale_k, devPtrDescaleK}, + {descale_v, devPtrDescaleV}, + {descale_s, devPtrDescaleS}, + {scale_s, devPtrScaleS}, + {scale_o, devPtrScaleO}, + {attn_scale, &scaling_factor}, + {O, devPtrO}, + {amax_s, devPtrAmaxS}, + {amax_o, devPtrAmaxO}, + {Stats, devPtrM}}; + + // if (is_bias) { + // variant_pack[bias] = devPtrBias; + // } + + // if (is_padding) { + // constexpr size_t nthreads_per_block = 128; + // const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + // void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + // void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + // + b * sizeof(int32_t); + // cu_seqlens_to_actual_seqlens<<>>( + // b, static_cast(devPtrCuSeqlensQ), + // static_cast(devPtrCuSeqlensKV), + // static_cast(devActualSeqlenQ), + // static_cast(devActualSeqlenKV)); + // variant_pack[seq_q] = devActualSeqlenQ; + // variant_pack[seq_kv] = devActualSeqlenKV; + // } + + // if (is_dropout) { + // variant_pack[dropout_seed] = devPtrDropoutSeed; + // variant_pack[dropout_offset] = devPtrDropoutOffset; + // } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); + } catch (cudnn_frontend::cudnnException& e) { + NVTE_ERROR(e.what()); + } } // fused attention BWD FP8 with FE 1.0+ -void fused_attn_fp8_bwd_impl_v1(int64_t b, int64_t h, int64_t hg, - int64_t s_q, int64_t s_kv, int64_t d, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, - void* devPtrO, void* devPtrdO, - void* devPtrdQ, void* devPtrdK, void* devPtrdV, - void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, - void* devPtrDescaleO, void* devPtrDescaledO, - void* devPtrDescaleS, void* devPtrDescaledP, - void* devPtrScaleS, void* devPtrScaledP, - void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, - void* devPtrAmaxdP, - void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, - void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, - cudnn_frontend::DataType_t fwd_tensor_type, - cudnn_frontend::DataType_t bwd_tensor_type, - void* workspace, - size_t* workspace_size, - cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); - bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) - || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) - || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); - bool is_dropout = (dropout_probability != 0.0f); - auto bias_b = b; - auto bias_h = h; - NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); - NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - NVTE_CHECK(~is_padding, - "FP8 fused attention does not support padding/padding_causal mask yet!"); - NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!"); - - try { - FADescriptor_v1 descriptor{b, h, - hg, s_q, - s_kv, d, - bias_b, bias_h, - scaling_factor, true, - dropout_probability, layout, - bias_type, mask_type, 
- fwd_tensor_type, bwd_tensor_type}; - - namespace fe = cudnn_frontend; - using graph_and_tensors = std::tuple, - std::shared_ptr, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats - std::shared_ptr, // dO - std::shared_ptr, // attn_scale - std::shared_ptr, // descale_q - std::shared_ptr, // descale_k - std::shared_ptr, // descale_v - std::shared_ptr, // descale_o - std::shared_ptr, // descale_dO - std::shared_ptr, // descale_s - std::shared_ptr, // descale_dP - std::shared_ptr, // scale_dQ - std::shared_ptr, // scale_dK - std::shared_ptr, // scale_dV - std::shared_ptr, // scale_s - std::shared_ptr, // scale_dP - std::shared_ptr, // dQ - std::shared_ptr, // dK - std::shared_ptr, // dV - std::shared_ptr, // amax_dQ - std::shared_ptr, // amax_dK - std::shared_ptr, // amax_dV - std::shared_ptr, // amax_dP - std::shared_ptr, // bias - std::shared_ptr, // dBias - std::shared_ptr, // seq_q - std::shared_ptr, // seq_kv - std::shared_ptr, // dropout_seed - std::shared_ptr >; // dropout_offset - - using CacheType = std::map; - static thread_local CacheType sdpa_fp8_bprop_cache; - - // Get plan from cache if cache is available, otherwise create one - auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor) - -> graph_and_tensors { - // if hit, return - auto it = cache.find(descriptor); - if (it != cache.end()) { - auto graph = it->second; - return graph; - } - - // otherwise, build the op_graph and the plan. Then update cache - auto mha_graph = std::make_shared(); - - mha_graph->set_io_data_type(fwd_tensor_type) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - std::shared_ptr q, k, v, o, dO, stats, attn_scale; - std::shared_ptr descale_q, descale_k, descale_v; - std::shared_ptr descale_s, descale_o; - std::shared_ptr descale_dP, descale_dO; - std::shared_ptr scale_s, scale_dP; - std::shared_ptr scale_dQ, scale_dK, scale_dV; - std::shared_ptr bias, dBias, seq_q, seq_kv; - std::shared_ptr dropout_seed, dropout_offset; - - std::vector q_stride(4); - std::vector k_stride(4); - std::vector v_stride(4); - std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("O") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); - dO = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - - attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 
1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); - - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); - descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); - descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); - scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); - scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); - - fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; - sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes() - .set_name("sdpa_fp8_backward") - .set_causal_mask(is_causal) - .set_attn_scale(attn_scale); - - // sdpa_backward_options.set_alibi_mask(is_alibi); - - // if (is_bias) { - // bias = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("bias") - // .set_dim({bias_b, bias_h, s_q, s_kv}) - // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); - // dBias = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("dBias") - // .set_dim({bias_b, bias_h, s_q, s_kv}) - // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); - // sdpa_backward_options.set_bias(bias); - // // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] - // // are not supported for dbias calculation but they are - // // supported for forward bias calculation - // if ((bias_b == 1) && (bias_h == h)) { - // sdpa_backward_options.set_dbias(dBias); - // } - // } - - // if (is_padding) { - // seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_q") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_kv") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // sdpa_backward_options.set_padding_mask(is_padding) - // .set_seq_len_q(seq_q) - // .set_seq_len_kv(seq_kv); - // } - - // if (is_dropout) { - // dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Seed") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Offset") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // sdpa_backward_options.set_dropout( - // dropout_probability, dropout_seed, dropout_offset); - // } - - auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( - q, k, v, o, dO, stats, - descale_q, descale_k, descale_v, - descale_o, descale_dO, descale_s, descale_dP, - scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, - sdpa_backward_options); - - dQ->set_output(true) - .set_dim({b, h, s_q, d}) - .set_stride(q_stride); - dK->set_output(true) - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride); - dV->set_output(true) - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride); - amax_dQ->set_output(true) - .set_dim({1, 1, 1, 1}) - 
.set_data_type(fe::DataType_t::FLOAT); - amax_dK->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); - amax_dV->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); - amax_dP->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); - - dO->set_data_type(bwd_tensor_type); - dQ->set_data_type(bwd_tensor_type); - dK->set_data_type(bwd_tensor_type); - dV->set_data_type(bwd_tensor_type); - - std::tuple, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats - std::shared_ptr, // dO - std::shared_ptr, // attn_scale - std::shared_ptr, // descale_q - std::shared_ptr, // descale_k - std::shared_ptr, // descale_v - std::shared_ptr, // descale_o - std::shared_ptr, // descale_dO - std::shared_ptr, // descale_s - std::shared_ptr, // descale_dP - std::shared_ptr, // scale_dQ - std::shared_ptr, // scale_dK - std::shared_ptr, // scale_dV - std::shared_ptr, // scale_s - std::shared_ptr, // scale_dP - std::shared_ptr, // dQ - std::shared_ptr, // dK - std::shared_ptr, // dV - std::shared_ptr, // amax_dQ - std::shared_ptr, // amax_dK - std::shared_ptr, // amax_dV - std::shared_ptr > // amax_dP - key_tensors_tuple = std::make_tuple( - q, k, v, o, stats, dO, attn_scale, - descale_q, descale_k, descale_v, - descale_o, descale_dO, descale_s, descale_dP, - scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, - dQ, dK, dV, - amax_dQ, amax_dK, amax_dV, amax_dP); - auto bias_tuple = is_bias ? - std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); - auto padding_tuple = is_padding ? - std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto dropout_tuple = is_dropout ? - std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); - - NVTE_CHECK_CUDNN_FE(mha_graph->validate()); - NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); - NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); - NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); - NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - - auto return_tuple = std::tuple_cat( - std::make_tuple(mha_graph), key_tensors_tuple, - bias_tuple, padding_tuple, dropout_tuple); - cache.insert({descriptor, return_tuple}); - - return return_tuple; - }; - - auto [mha_graph, q, k, v, o, stats, dO, attn_scale, - descale_q, descale_k, descale_v, - descale_o, descale_dO, descale_s, descale_dP, - scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, - dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, - bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph( - sdpa_fp8_bprop_cache, descriptor); - - auto plan_workspace_size = mha_graph->get_workspace_size(); - - // Exit to request upper level API to allocate memory if needed - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); - if (workspace == nullptr) { - *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; - return; - } - - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. 
-    NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
-
-    // build variant pack
-    std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
-        {q, devPtrQ},
-        {k, devPtrK},
-        {v, devPtrV},
-        {o, devPtrO},
-        {stats, devPtrM},
-        {dO, devPtrdO},
-        {attn_scale, &scaling_factor},
-        {descale_q, devPtrDescaleQ},
-        {descale_k, devPtrDescaleK},
-        {descale_v, devPtrDescaleV},
-        {descale_o, devPtrDescaleO},
-        {descale_dO, devPtrDescaledO},
-        {descale_s, devPtrDescaleS},
-        {descale_dP, devPtrDescaledP},
-        {scale_s, devPtrScaleS},
-        {scale_dQ, devPtrScaledQ},
-        {scale_dK, devPtrScaledK},
-        {scale_dV, devPtrScaledV},
-        {scale_dP, devPtrScaledP},
-        {dQ, devPtrdQ},
-        {dK, devPtrdK},
-        {dV, devPtrdV},
-        {amax_dQ, devPtrAmaxdQ},
-        {amax_dK, devPtrAmaxdK},
-        {amax_dV, devPtrAmaxdV},
-        {amax_dP, devPtrAmaxdP},
-    };
-
-    // if (is_bias) {
-    //   variant_pack[bias] = devPtrBias;
-    //   if ((bias_b == 1) && (bias_h == h)) {
-    //     variant_pack[dBias] = devPtrdBias;
-    //   } else {
-    //     variant_pack[dBias] = nullptr;
-    //   }
-    // }
-
-    // if (is_padding) {
-    //   constexpr size_t nthreads_per_block = 128;
-    //   const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
-    //   void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
-    //   void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ)
-    //                             + b * sizeof(int32_t);
-    //   cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
-    //       b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
-    //       static_cast<const int32_t *>(devPtrCuSeqlensKV),
-    //       static_cast<int32_t *>(devActualSeqlenQ),
-    //       static_cast<int32_t *>(devActualSeqlenKV));
-    //   variant_pack[seq_q] = devActualSeqlenQ;
-    //   variant_pack[seq_kv] = devActualSeqlenKV;
-    // }
-
-    // if (is_dropout) {
-    //   variant_pack[dropout_seed] = devPtrDropoutSeed;
-    //   variant_pack[dropout_offset] = devPtrDropoutOffset;
-    // }
-
-    NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
-  } catch (cudnn_frontend::cudnnException &e) {
-    NVTE_ERROR(e.what());
+void fused_attn_fp8_bwd_impl_v1(
+    int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor,
+    float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM,
+    void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV,
+    void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO,
+    void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS,
+    void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV,
+    void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV,
+    void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed,
+    void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type,
+    cudnn_frontend::DataType_t bwd_tensor_type, void* workspace, size_t* workspace_size,
+    cudaStream_t stream, cudnnHandle_t handle) {
+  using namespace transformer_engine;
+  bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
+  bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
+  bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
+                    (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
+  bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
+                     (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
+  bool is_dropout = (dropout_probability != 0.0f);
+  auto bias_b = b;
+  auto bias_h = h;
+  NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
+  NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!");
+  NVTE_CHECK(!is_padding, "FP8 fused attention does not support padding/padding_causal mask yet!");
+  NVTE_CHECK(!is_dropout, "FP8 fused attention does not support dropout yet!");
+
+  try {
+    FADescriptor_v1 descriptor{b,
+                               h,
+                               hg,
+                               s_q,
+                               s_kv,
+                               d,
+                               bias_b,
+                               bias_h,
+                               scaling_factor,
+                               true,
+                               dropout_probability,
+                               layout,
+                               bias_type,
+                               mask_type,
+                               fwd_tensor_type,
+                               bwd_tensor_type};
+
+    namespace fe = cudnn_frontend;
+    using graph_and_tensors =
+        std::tuple<std::shared_ptr<fe::graph::Graph>,
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // q
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // k
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // v
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // o
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // stats
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // dO
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // attn_scale
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // descale_q
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // descale_k
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // descale_v
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // descale_o
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // descale_dO
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // descale_s
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // descale_dP
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // scale_dQ
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // scale_dK
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // scale_dV
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // scale_s
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // scale_dP
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // dQ
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // dK
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // dV
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // amax_dQ
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // amax_dK
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // amax_dV
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // amax_dP
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // bias
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // dBias
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // seq_q
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // seq_kv
+                   std::shared_ptr<fe::graph::Tensor_attributes>,   // dropout_seed
+                   std::shared_ptr<fe::graph::Tensor_attributes>>;  // dropout_offset
+
+    using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
+    static thread_local CacheType sdpa_fp8_bprop_cache;
+
+    // Get plan from cache if cache is available, otherwise create one
+    auto get_graph = [&](CacheType& cache, const FADescriptor_v1& descriptor) -> graph_and_tensors {
+      // if hit, return
+      auto it = cache.find(descriptor);
+      if (it != cache.end()) {
+        auto graph = it->second;
+        return graph;
+      }
+
+      // otherwise, build the op_graph and the plan.
Then update cache + auto mha_graph = std::make_shared(); + + mha_graph->set_io_data_type(fwd_tensor_type) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + std::shared_ptr q, k, v, o, dO, stats, attn_scale; + std::shared_ptr descale_q, descale_k, descale_v; + std::shared_ptr descale_s, descale_o; + std::shared_ptr descale_dP, descale_dO; + std::shared_ptr scale_s, scale_dP; + std::shared_ptr scale_dQ, scale_dK, scale_dV; + std::shared_ptr bias, dBias, seq_q, seq_kv; + std::shared_ptr dropout_seed, dropout_offset; + + std::vector q_stride(4); + std::vector k_stride(4); + std::vector v_stride(4); + std::vector o_stride(4); + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_O_Matrix); + q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride(q_stride)); + k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride)); + v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride)); + o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim({b, h, s_q, d}) + .set_stride(o_stride)); + dO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO") + .set_dim({b, h, s_q, d}) + .set_stride(o_stride)); + stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + + attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); + descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); + descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); + scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); + scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); + scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); + scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + + fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; + sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes() + .set_name("sdpa_fp8_backward") + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale); + + // sdpa_backward_options.set_alibi_mask(is_alibi); + + // if (is_bias) { + // bias = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("bias") + // .set_dim({bias_b, bias_h, s_q, s_kv}) + // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // dBias = mha_graph->tensor(fe::graph::Tensor_attributes() + 
// .set_name("dBias") + // .set_dim({bias_b, bias_h, s_q, s_kv}) + // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // sdpa_backward_options.set_bias(bias); + // // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] + // // are not supported for dbias calculation but they are + // // supported for forward bias calculation + // if ((bias_b == 1) && (bias_h == h)) { + // sdpa_backward_options.set_dbias(dBias); + // } + // } + + // if (is_padding) { + // seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("seq_q") + // .set_dim({b, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + // seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("seq_kv") + // .set_dim({b, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + // sdpa_backward_options.set_padding_mask(is_padding) + // .set_seq_len_q(seq_q) + // .set_seq_len_kv(seq_kv); + // } + + // if (is_dropout) { + // dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("Seed") + // .set_dim({1, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT64)); + // dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("Offset") + // .set_dim({1, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT64)); + // sdpa_backward_options.set_dropout( + // dropout_probability, dropout_seed, dropout_offset); + // } + + auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( + q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); + + dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); + dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); + dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); + amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + + dO->set_data_type(bwd_tensor_type); + dQ->set_data_type(bwd_tensor_type); + dK->set_data_type(bwd_tensor_type); + dV->set_data_type(bwd_tensor_type); + + std::tuple, // q + std::shared_ptr, // k + std::shared_ptr, // v + std::shared_ptr, // o + std::shared_ptr, // stats + std::shared_ptr, // dO + std::shared_ptr, // attn_scale + std::shared_ptr, // descale_q + std::shared_ptr, // descale_k + std::shared_ptr, // descale_v + std::shared_ptr, // descale_o + std::shared_ptr, // descale_dO + std::shared_ptr, // descale_s + std::shared_ptr, // descale_dP + std::shared_ptr, // scale_dQ + std::shared_ptr, // scale_dK + std::shared_ptr, // scale_dV + std::shared_ptr, // scale_s + std::shared_ptr, // scale_dP + std::shared_ptr, // dQ + std::shared_ptr, // dK + std::shared_ptr, // dV + std::shared_ptr, // amax_dQ + std::shared_ptr, // amax_dK + std::shared_ptr, // amax_dV + std::shared_ptr> // amax_dP + key_tensors_tuple = std::make_tuple( + q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, + dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP); + auto bias_tuple = is_bias ? 
std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); + auto padding_tuple = + is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); + auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) + : std::make_tuple(nullptr, nullptr); + + NVTE_CHECK_CUDNN_FE(mha_graph->validate()); + NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); + NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); + NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); + NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); + + auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, + padding_tuple, dropout_tuple); + cache.insert({descriptor, return_tuple}); + + return return_tuple; + }; + + auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, + dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, bias, dBias, seq_q, seq_kv, dropout_seed, + dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + + auto plan_workspace_size = mha_graph->get_workspace_size(); + + // Exit to request upper level API to allocate memory if needed + size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); + if (workspace == nullptr) { + *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; + return; } + + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + + // build variant pack + std::unordered_map, void*> variant_pack = { + {q, devPtrQ}, + {k, devPtrK}, + {v, devPtrV}, + {o, devPtrO}, + {stats, devPtrM}, + {dO, devPtrdO}, + {attn_scale, &scaling_factor}, + {descale_q, devPtrDescaleQ}, + {descale_k, devPtrDescaleK}, + {descale_v, devPtrDescaleV}, + {descale_o, devPtrDescaleO}, + {descale_dO, devPtrDescaledO}, + {descale_s, devPtrDescaleS}, + {descale_dP, devPtrDescaledP}, + {scale_s, devPtrScaleS}, + {scale_dQ, devPtrScaledQ}, + {scale_dK, devPtrScaledK}, + {scale_dV, devPtrScaledV}, + {scale_dP, devPtrScaledP}, + {dQ, devPtrdQ}, + {dK, devPtrdK}, + {dV, devPtrdV}, + {amax_dQ, devPtrAmaxdQ}, + {amax_dK, devPtrAmaxdK}, + {amax_dV, devPtrAmaxdV}, + {amax_dP, devPtrAmaxdP}, + }; + + // if (is_bias) { + // variant_pack[bias] = devPtrBias; + // if ((bias_b == 1) && (bias_h == h)) { + // variant_pack[dBias] = devPtrdBias; + // } else { + // variant_pack[dBias] = nullptr; + // } + // } + + // if (is_padding) { + // constexpr size_t nthreads_per_block = 128; + // const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + // void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + // void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + // + b * sizeof(int32_t); + // cu_seqlens_to_actual_seqlens<<>>( + // b, static_cast(devPtrCuSeqlensQ), + // static_cast(devPtrCuSeqlensKV), + // static_cast(devActualSeqlenQ), + // static_cast(devActualSeqlenKV)); + // variant_pack[seq_q] = devActualSeqlenQ; + // variant_pack[seq_kv] = devActualSeqlenKV; + // } + + // if (is_dropout) { + // variant_pack[dropout_seed] = devPtrDropoutSeed; + // variant_pack[dropout_offset] = devPtrDropoutOffset; + // } + + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); + } catch (cudnn_frontend::cudnnException& e) { + NVTE_ERROR(e.what()); + } } #endif @@ -2552,33 +2322,27 @@ void 
fused_attn_fp8_bwd_impl_v1(int64_t b, int64_t h, int64_t hg, #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, - bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, - Tensor *input_output_S, - Tensor *output_O, - NVTETensorPack* Aux_CTX_Tensors, - const Tensor *cu_seqlens, - const Tensor *rng_state, - Tensor *workspace, - cudaStream_t stream, - cudnnHandle_t handle) { +void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, + size_t head_dim, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor* input_QKV, Tensor* input_output_S, Tensor* output_O, + NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens, + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, + cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_QKV->data.dtype; void* devPtrQKV = input_QKV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + stride = typeToSize(QKV_type) * num_attn_heads * head_dim; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = typeToSize(QKV_type) * head_dim; + stride = typeToSize(QKV_type) * head_dim; } - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); + void* devPtrQ = static_cast(devPtrQKV); + void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); + void* devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); void* devPtrDescaleQ = input_QKV->scale_inv.dptr; void* devPtrDescaleK = input_QKV->scale_inv.dptr; void* devPtrDescaleV = input_QKV->scale_inv.dptr; @@ -2591,9 +2355,9 @@ void fused_attn_fp8_fwd_qkvpacked( void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 3; - Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor* output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor* output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor* output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1}; output_M->data.dtype = DType::kFloat32; @@ -2604,9 +2368,9 @@ void fused_attn_fp8_fwd_qkvpacked( output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor* output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor* output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor* output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); devPtrM = output_M->data.dptr; devPtrZInv = output_ZInv->data.dptr; output_rng_state->data.dptr = 
rng_state->data.dptr; @@ -2618,79 +2382,55 @@ void fused_attn_fp8_fwd_qkvpacked( void* devPtrScaleS = input_output_S->scale.dptr; void* devPtrDescaleS = input_output_S->scale_inv.dptr; - void* devPtrcuSeqlens = reinterpret_cast( - reinterpret_cast(cu_seqlens->data.dptr)); - void* devPtrDropoutSeed = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr) + 1); + void* devPtrcuSeqlens = + reinterpret_cast(reinterpret_cast(cu_seqlens->data.dptr)); + void* devPtrDropoutSeed = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) - || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, - devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, - devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + fused_attn::fused_attn_fp8_fwd_impl_v1( + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, + devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, + devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, - is_training, attn_scale, p_dropout, qkv_layout, - devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, - devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, - devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + fused_attn::fused_attn_fp8_fwd_impl( + batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, + qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, + devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, + devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); } else { NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. 
\n"); } if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { - workspace->data.shape = { workspace_size }; + workspace->data.shape = {workspace_size}; workspace->data.dtype = DType::kByte; return; } } else if (workspace_size == 0) { - workspace->data.shape = { 1 }; + workspace->data.shape = {1}; workspace->data.dtype = DType::kByte; return; } } // fused attention BWD FP8 with packed QKV void fused_attn_fp8_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, - const Tensor *input_O, - const Tensor *input_dO, - const Tensor *input_M, - const Tensor *input_ZInv, - const Tensor *input_S, - Tensor *input_output_dP, - const Tensor *output_dQKV, - const Tensor *cu_seqlens, - const Tensor *rng_state, - Tensor *workspace, - cudaStream_t stream, - cudnnHandle_t handle) { + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, + const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, + const Tensor* output_dQKV, const Tensor* cu_seqlens, const Tensor* rng_state, Tensor* workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_QKV->data.dtype; const DType dQKV_type = output_dQKV->data.dtype; @@ -2698,13 +2438,13 @@ void fused_attn_fp8_bwd_qkvpacked( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + stride = typeToSize(QKV_type) * num_attn_heads * head_dim; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = typeToSize(QKV_type) * head_dim; + stride = typeToSize(QKV_type) * head_dim; } - void *devPtrQ = devPtrQKV; - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); + void* devPtrQ = devPtrQKV; + void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); + void* devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); void* devPtrDescaleQ = input_QKV->scale_inv.dptr; void* devPtrDescaleK = input_QKV->scale_inv.dptr; void* devPtrDescaleV = input_QKV->scale_inv.dptr; @@ -2723,10 +2463,10 @@ void fused_attn_fp8_bwd_qkvpacked( void* devPtrScaledP = input_output_dP->scale.dptr; void* devPtrDescaledP = input_output_dP->scale_inv.dptr; - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = devPtrdQKV; - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); + void* devPtrdQKV = output_dQKV->data.dptr; + void* devPtrdQ = devPtrdQKV; + void* devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); + void* devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); void* devPtrAmaxdQ = output_dQKV->amax.dptr; void* devPtrAmaxdK = output_dQKV->amax.dptr; void* devPtrAmaxdV = output_dQKV->amax.dptr; @@ -2734,103 +2474,74 @@ void fused_attn_fp8_bwd_qkvpacked( void* devPtrScaledK = output_dQKV->scale.dptr; void* devPtrScaledV = output_dQKV->scale.dptr; - void* devPtrcuSeqlens = reinterpret_cast( - reinterpret_cast(cu_seqlens->data.dptr)); - void* 
devPtrDropoutSeed = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr) + 1); + void* devPtrcuSeqlens = + reinterpret_cast(reinterpret_cast(cu_seqlens->data.dptr)); + void* devPtrDropoutSeed = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) - || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, - devPtrO, devPtrdO, - devPtrdQ, devPtrdK, devPtrdV, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, - devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + fused_attn::fused_attn_fp8_bwd_impl_v1( + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, + p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, + devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, + devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens, + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, - attn_scale, p_dropout, qkv_layout, - devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, - devPtrO, devPtrdO, - devPtrdQ, devPtrdK, devPtrdV, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, - devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + fused_attn::fused_attn_fp8_bwd_impl( + batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, + devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, + devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, + devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, + devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } 
else { NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); } if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { - workspace->data.shape = { workspace_size }; + workspace->data.shape = {workspace_size}; workspace->data.dtype = DType::kByte; return; } } else if (workspace_size == 0) { - workspace->data.shape = { 1 }; + workspace->data.shape = {1}; workspace->data.dtype = DType::kByte; return; } } // fused attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, - const Tensor *input_KV, - Tensor *input_output_S, - Tensor *output_O, - NVTETensorPack* Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, - const Tensor *rng_state, - Tensor *workspace, - cudaStream_t stream, - cudnnHandle_t handle) { +void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, const Tensor* input_Q, + const Tensor* input_KV, Tensor* input_output_S, Tensor* output_O, + NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, + const Tensor* cu_seqlens_kv, const Tensor* rng_state, + Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_Q->data.dtype; void* devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; + void* devPtrKV = input_KV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = typeToSize(QKV_type) * head_dim; + stride = typeToSize(QKV_type) * head_dim; } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); + void* devPtrK = devPtrKV; + void* devPtrV = static_cast(static_cast(devPtrKV) + stride); void* devPtrDescaleQ = input_Q->scale_inv.dptr; void* devPtrDescaleK = input_KV->scale_inv.dptr; void* devPtrDescaleV = input_KV->scale_inv.dptr; @@ -2843,9 +2554,9 @@ void fused_attn_fp8_fwd_kvpacked( void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 3; - Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor* output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor* output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor* output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; output_M->data.dtype = DType::kFloat32; @@ -2856,9 +2567,9 @@ void fused_attn_fp8_fwd_kvpacked( output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_M = 
reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor* output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor* output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor* output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); devPtrM = output_M->data.dptr; devPtrZInv = output_ZInv->data.dptr; output_rng_state->data.dptr = rng_state->data.dptr; @@ -2870,99 +2581,74 @@ void fused_attn_fp8_fwd_kvpacked( void* devPtrScaleS = input_output_S->scale.dptr; void* devPtrDescaleS = input_output_S->scale_inv.dptr; - void* devPtrcuSeqlensQ = reinterpret_cast( - reinterpret_cast(cu_seqlens_q->data.dptr)); - void* devPtrcuSeqlensKV = reinterpret_cast( - reinterpret_cast(cu_seqlens_kv->data.dptr)); - void* devPtrDropoutSeed = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr) + 1); + void* devPtrcuSeqlensQ = + reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); + void* devPtrcuSeqlensKV = + reinterpret_cast(reinterpret_cast(cu_seqlens_kv->data.dptr)); + void* devPtrDropoutSeed = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) - || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, - devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + fused_attn::fused_attn_fp8_fwd_impl_v1( + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, + devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, + devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, - is_training, attn_scale, p_dropout, qkv_layout, - devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, - devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + fused_attn::fused_attn_fp8_fwd_impl( + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, + p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, 
devPtrZInv, devPtrO, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, + devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, + devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); } else { NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); } if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { - workspace->data.shape = { workspace_size }; + workspace->data.shape = {workspace_size}; workspace->data.dtype = DType::kByte; return; } } else if (workspace_size == 0) { - workspace->data.shape = { 1 }; + workspace->data.shape = {1}; workspace->data.dtype = DType::kByte; return; } } // fused attention BWD FP8 with packed KV void fused_attn_fp8_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, - const Tensor *input_KV, - const Tensor *input_O, - const Tensor *input_dO, - const Tensor *input_M, - const Tensor *input_ZInv, - const Tensor *input_S, - Tensor *input_output_dP, - const Tensor *output_dQ, - const Tensor *output_dKV, - const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, - const Tensor *rng_state, - Tensor *workspace, - cudaStream_t stream, - cudnnHandle_t handle) { + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, + const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, + const Tensor* output_dQ, const Tensor* output_dKV, const Tensor* cu_seqlens_q, + const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, + cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_Q->data.dtype; const DType dQKV_type = output_dQ->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; + void* devPtrQ = input_Q->data.dptr; + void* devPtrKV = input_KV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = typeToSize(QKV_type) * head_dim; + stride = typeToSize(QKV_type) * head_dim; } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); + void* devPtrK = devPtrKV; + void* devPtrV = static_cast(static_cast(devPtrKV) + stride); void* devPtrDescaleQ = input_Q->scale_inv.dptr; void* devPtrDescaleK = input_KV->scale_inv.dptr; void* devPtrDescaleV = input_KV->scale_inv.dptr; @@ -2981,10 +2667,10 @@ void fused_attn_fp8_bwd_kvpacked( void* devPtrScaledP = input_output_dP->scale.dptr; void* devPtrDescaledP = input_output_dP->scale_inv.dptr; - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdKV = output_dKV->data.dptr; - void *devPtrdK = devPtrdKV; - void *devPtrdV = static_cast(static_cast(devPtrdKV) + stride); + void* devPtrdQ = 
output_dQ->data.dptr; + void* devPtrdKV = output_dKV->data.dptr; + void* devPtrdK = devPtrdKV; + void* devPtrdV = static_cast(static_cast(devPtrdKV) + stride); void* devPtrAmaxdQ = output_dQ->amax.dptr; void* devPtrAmaxdK = output_dKV->amax.dptr; void* devPtrAmaxdV = output_dKV->amax.dptr; @@ -2992,93 +2678,63 @@ void fused_attn_fp8_bwd_kvpacked( void* devPtrScaledK = output_dKV->scale.dptr; void* devPtrScaledV = output_dKV->scale.dptr; - void* devPtrcuSeqlensQ = reinterpret_cast( - reinterpret_cast(cu_seqlens_q->data.dptr)); - void* devPtrcuSeqlensKV = reinterpret_cast( - reinterpret_cast(cu_seqlens_kv->data.dptr)); - void* devPtrDropoutSeed = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr) + 1); + void* devPtrcuSeqlensQ = + reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); + void* devPtrcuSeqlensKV = + reinterpret_cast(reinterpret_cast(cu_seqlens_kv->data.dptr)); + void* devPtrDropoutSeed = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) - || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, - devPtrO, devPtrdO, - devPtrdQ, devPtrdK, devPtrdV, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + fused_attn::fused_attn_fp8_bwd_impl_v1( + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, + p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, + devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, + devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, - attn_scale, p_dropout, qkv_layout, - devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, - devPtrO, devPtrdO, - devPtrdQ, devPtrdK, devPtrdV, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, 
devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + fused_attn::fused_attn_fp8_bwd_impl( + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, + qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, + devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, + devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, + devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, + devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); } else { NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); } if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { - workspace->data.shape = { workspace_size }; + workspace->data.shape = {workspace_size}; workspace->data.dtype = DType::kByte; return; } } else if (workspace_size == 0) { - workspace->data.shape = { 1 }; + workspace->data.shape = {1}; workspace->data.dtype = DType::kByte; return; } } // fused attention FWD FP8 with separate Q, K, V -void fused_attn_fp8_fwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, - const Tensor *input_K, - const Tensor *input_V, - Tensor *input_output_S, - Tensor *output_O, - NVTETensorPack* Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, - const Tensor *rng_state, - Tensor *workspace, - cudaStream_t stream, - cudnnHandle_t handle) { +void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, + const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, + NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, + const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; @@ -3095,9 +2751,9 @@ void fused_attn_fp8_fwd( void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 3; - Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor* output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor* output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor* output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; output_M->data.dtype = DType::kFloat32; @@ -3108,9 +2764,9 @@ void fused_attn_fp8_fwd( output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor *output_ZInv = 
reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor* output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor* output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor* output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); devPtrM = output_M->data.dptr; devPtrZInv = output_ZInv->data.dptr; output_rng_state->data.dptr = rng_state->data.dptr; @@ -3122,88 +2778,63 @@ void fused_attn_fp8_fwd( void* devPtrScaleS = input_output_S->scale.dptr; void* devPtrDescaleS = input_output_S->scale_inv.dptr; - void* devPtrcuSeqlensQ = reinterpret_cast( - reinterpret_cast(cu_seqlens_q->data.dptr)); - void* devPtrcuSeqlensKV = reinterpret_cast( - reinterpret_cast(cu_seqlens_kv->data.dptr)); - void* devPtrDropoutSeed = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr) + 1); + void* devPtrcuSeqlensQ = + reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); + void* devPtrcuSeqlensKV = + reinterpret_cast(reinterpret_cast(cu_seqlens_kv->data.dptr)); + void* devPtrDropoutSeed = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); const DType QKV_type = input_Q->data.dtype; size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) - || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, - devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + fused_attn::fused_attn_fp8_fwd_impl_v1( + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, + devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, + devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, - is_training, attn_scale, p_dropout, qkv_layout, - devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, - devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + fused_attn::fused_attn_fp8_fwd_impl( + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, + p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, + 
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, + devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, + devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); } else { NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); } if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { - workspace->data.shape = { workspace_size }; + workspace->data.shape = {workspace_size}; workspace->data.dtype = DType::kByte; return; } } else if (workspace_size == 0) { - workspace->data.shape = { 1 }; + workspace->data.shape = {1}; workspace->data.dtype = DType::kByte; return; } } // fused attention BWD FP8 with separate Q, K, V -void fused_attn_fp8_bwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, - const Tensor *input_K, - const Tensor *input_V, - const Tensor *input_O, - const Tensor *input_dO, - const Tensor *input_M, - const Tensor *input_ZInv, - const Tensor *input_S, - Tensor *input_output_dP, - const Tensor *output_dQ, - const Tensor *output_dK, - const Tensor *output_dV, - const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, - const Tensor *rng_state, - Tensor *workspace, - cudaStream_t stream, - cudnnHandle_t handle) { +void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q, + const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, + const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv, + const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, + const Tensor* output_dK, const Tensor* output_dV, + const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, + cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; @@ -3236,72 +2867,51 @@ void fused_attn_fp8_bwd( void* devPtrScaledK = output_dQ->scale.dptr; void* devPtrScaledV = output_dQ->scale.dptr; - void* devPtrcuSeqlensQ = reinterpret_cast( - reinterpret_cast(cu_seqlens_q->data.dptr)); - void* devPtrcuSeqlensKV = reinterpret_cast( - reinterpret_cast(cu_seqlens_kv->data.dptr)); - void* devPtrDropoutSeed = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = reinterpret_cast( - reinterpret_cast(rng_state->data.dptr) + 1); + void* devPtrcuSeqlensQ = + reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); + void* devPtrcuSeqlensKV = + reinterpret_cast(reinterpret_cast(cu_seqlens_kv->data.dptr)); + void* devPtrDropoutSeed = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = + reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); const DType QKV_type = input_Q->data.dtype; const DType dQKV_type = output_dQ->data.dtype; size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) - || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, 
num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, - devPtrO, devPtrdO, - devPtrdQ, devPtrdK, devPtrdV, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + fused_attn::fused_attn_fp8_bwd_impl_v1( + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, + p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, + devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, + devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, - attn_scale, p_dropout, qkv_layout, - devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, - devPtrO, devPtrdO, - devPtrdQ, devPtrdK, devPtrdV, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + fused_attn::fused_attn_fp8_bwd_impl( + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, + qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, + devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, + devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, + devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, + devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); } else { NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. 
\n"); } if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { - workspace->data.shape = { workspace_size }; + workspace->data.shape = {workspace_size}; workspace->data.dtype = DType::kByte; return; } } else if (workspace_size == 0) { - workspace->data.shape = { 1 }; + workspace->data.shape = {1}; workspace->data.dtype = DType::kByte; return; } diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 83da421561fd088bc754bf9e29c347b29d728d8a..55830d3cda11028a50727b81a550670c83a1d4dc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -8,127 +8,74 @@ * \brief Functions for fused attention for FP8 with seqlen <= 512 */ -#include "transformer_engine/transformer_engine.h" #include "transformer_engine/fused_attn.h" +#include "transformer_engine/transformer_engine.h" namespace transformer_engine { #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, - bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, - Tensor *input_output_S, - Tensor *output_O, - NVTETensorPack* Aux_CTX_Tensors, - const Tensor *cu_seqlens, - const Tensor *rng_state, - Tensor *workspace, - cudaStream_t stream, - cudnnHandle_t handle); +void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, + size_t head_dim, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, + cudnnHandle_t handle); // fused attention BWD FP8 with packed QKV void fused_attn_fp8_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, - const Tensor *input_O, - const Tensor *input_dO, - const Tensor *input_M, - const Tensor *input_ZInv, - const Tensor *input_S, - Tensor *input_output_dP, - const Tensor *output_dQKV, - const Tensor *cu_seqlens, - const Tensor *rng_state, - Tensor *workspace, - cudaStream_t stream, - cudnnHandle_t handle); + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, + const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, + const Tensor *output_dQKV, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); // fused attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, - const Tensor *input_KV, - Tensor *input_output_S, - Tensor *output_O, - 
NVTETensorPack* Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, - const Tensor *rng_state, - Tensor *workspace, - cudaStream_t stream, - cudnnHandle_t handle); +void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, const Tensor *input_Q, + const Tensor *input_KV, Tensor *input_output_S, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); // fused attention BWD FP8 with packed KV void fused_attn_fp8_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, - const Tensor *input_KV, - const Tensor *input_O, - const Tensor *input_dO, - const Tensor *input_M, - const Tensor *input_ZInv, - const Tensor *input_S, - Tensor *input_output_dP, - const Tensor *output_dQ, - const Tensor *output_dKV, - const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, - const Tensor *rng_state, - Tensor *workspace, - cudaStream_t stream, - cudnnHandle_t handle); + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, + const Tensor *output_dQ, const Tensor *output_dKV, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, + cudnnHandle_t handle); // fused attention FWD FP8 with separate Q, K, V -void fused_attn_fp8_fwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - Tensor *input_output_S, - Tensor *output_O, - NVTETensorPack* Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, - const Tensor *rng_state, - Tensor *workspace, - cudaStream_t stream, - cudnnHandle_t handle); +void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); // fused attention BWD FP8 with separate Q, K, V -void fused_attn_fp8_bwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t 
head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_O, - const Tensor *input_dO, - const Tensor *input_M, - const Tensor *input_ZInv, - const Tensor *input_S, - Tensor *input_output_dP, - const Tensor *output_dQ, - const Tensor *output_dK, - const Tensor *output_dV, - const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, - const Tensor *rng_state, - Tensor *workspace, - cudaStream_t stream, - cudnnHandle_t handle); +void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv, + const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, + const Tensor *output_dK, const Tensor *output_dV, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, + cudnnHandle_t handle); #endif // end of CUDNN>=8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 057615e6af831ce815bc18ae8ed1a6529b75730c..73bb5a7279d0d9cb52f1f5281b1045554789e1c2 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -4,8 +4,8 @@ * See LICENSE for license information. ************************************************************************/ -#include "transformer_engine/fused_attn.h" #include "../common.h" +#include "transformer_engine/fused_attn.h" #include "utils.h" namespace transformer_engine { @@ -14,242 +14,239 @@ namespace fused_attn { using namespace transformer_engine; // get matrix strides based on matrix type -void generateMatrixStrides( - int64_t b, int64_t h, - int64_t s_q, int64_t s_kv, - int64_t d, int64_t* strideA, - NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) { - constexpr int batch_dim_idx = 0; - constexpr int head_dim_idx = 1; - constexpr int seqlen_dim_idx = 2; - constexpr int hidden_dim_idx = 3; +void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) { + constexpr int batch_dim_idx = 0; + constexpr int head_dim_idx = 1; + constexpr int seqlen_dim_idx = 2; + constexpr int hidden_dim_idx = 3; - constexpr int seqlen_transpose_dim_idx = 3; - constexpr int hidden_transpose_dim_idx = 2; + constexpr int seqlen_transpose_dim_idx = 3; + constexpr int hidden_transpose_dim_idx = 2; - constexpr int seqlen_q_dim_idx = 2; - constexpr int seqlen_kv_dim_idx = 3; + constexpr int seqlen_q_dim_idx = 2; + constexpr int seqlen_kv_dim_idx = 3; - switch (layout) { - case NVTE_QKV_Layout::NVTE_SB3HD: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strideA[batch_dim_idx] = 3 * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = b * 3 * h * d; - strideA[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { - strideA[batch_dim_idx] = 3 * h * d; - 
strideA[head_dim_idx] = d; - strideA[seqlen_transpose_dim_idx] = b * 3 * h * d; - strideA[hidden_transpose_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) { - strideA[batch_dim_idx] = h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = b * h * d; - strideA[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBH3D: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strideA[batch_dim_idx] = 3 * h * d; - strideA[head_dim_idx] = 3 * d; - strideA[seqlen_dim_idx] = b * 3 * h * d; - strideA[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { - strideA[batch_dim_idx] = 3 * h * d; - strideA[head_dim_idx] = 3 * d; - strideA[seqlen_transpose_dim_idx] = b * 3 * h * d; - strideA[hidden_transpose_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) { - strideA[batch_dim_idx] = h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = b * h * d; - strideA[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: - if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strideA[batch_dim_idx] = 2 * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = b * 2 * h * d; - strideA[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { - strideA[batch_dim_idx] = 2 * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_transpose_dim_idx] = b * 2 * h * d; - strideA[hidden_transpose_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { - strideA[batch_dim_idx] = h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = b * h * d; - strideA[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: - if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strideA[batch_dim_idx] = 2 * h * d; - strideA[head_dim_idx] = 2 * d; - strideA[seqlen_dim_idx] = b * 2 * h * d; - strideA[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { - strideA[batch_dim_idx] = 2 * h * d; - strideA[head_dim_idx] = 2 * d; - strideA[seqlen_transpose_dim_idx] = b * 2 * h * d; - strideA[hidden_transpose_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { - strideA[batch_dim_idx] = h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = b * h * d; - strideA[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { - strideA[batch_dim_idx] = h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = b * h * d; - strideA[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { - strideA[batch_dim_idx] = h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_transpose_dim_idx] = b * h * d; - strideA[hidden_transpose_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BS3HD: - case NVTE_QKV_Layout::NVTE_T3HD: - if ((matrix 
== NVTE_QKV_Matrix::NVTE_Q_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strideA[batch_dim_idx] = s_q * 3 * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = 3 * h * d; - strideA[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { - strideA[batch_dim_idx] = s_q * 3 * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_transpose_dim_idx] = 3 * h * d; - strideA[hidden_transpose_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) { - strideA[batch_dim_idx] = s_q * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = h * d; - strideA[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSH3D: - case NVTE_QKV_Layout::NVTE_TH3D: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strideA[batch_dim_idx] = s_q * 3 * h * d; - strideA[head_dim_idx] = 3 * d; - strideA[seqlen_dim_idx] = 3 * h * d; - strideA[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { - strideA[batch_dim_idx] = s_q * 3 * h * d; - strideA[head_dim_idx] = 3 * d; - strideA[seqlen_transpose_dim_idx] = 3 * h * d; - strideA[hidden_transpose_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) { - strideA[batch_dim_idx] = s_q * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = h * d; - strideA[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: - case NVTE_QKV_Layout::NVTE_THD_T2HD: - if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strideA[batch_dim_idx] = s_kv * 2 * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = 2 * h * d; - strideA[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { - strideA[batch_dim_idx] = s_kv * 2 * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_transpose_dim_idx] = 2 * h * d; - strideA[hidden_transpose_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { - strideA[batch_dim_idx] = s_q * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = h * d; - strideA[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: - case NVTE_QKV_Layout::NVTE_THD_TH2D: - if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strideA[batch_dim_idx] = s_kv * 2 * h * d; - strideA[head_dim_idx] = 2 * d; - strideA[seqlen_dim_idx] = 2 * h * d; - strideA[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { - strideA[batch_dim_idx] = s_kv * 2 * h * d; - strideA[head_dim_idx] = 2 * d; - strideA[seqlen_transpose_dim_idx] = 2 * h * d; - strideA[hidden_transpose_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { - strideA[batch_dim_idx] = s_q * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = h * d; - strideA[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: - case NVTE_QKV_Layout::NVTE_THD_THD_THD: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) - || (matrix == 
NVTE_QKV_Matrix::NVTE_O_Matrix)) { - strideA[batch_dim_idx] = s_q * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = h * d; - strideA[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strideA[batch_dim_idx] = s_kv * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = h * d; - strideA[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) - || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { - strideA[batch_dim_idx] = s_kv * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_transpose_dim_idx] = h * d; - strideA[hidden_transpose_dim_idx] = 1; - } - break; - } + switch (layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * 3 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = b * 3 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = 3 * d; + strideA[seqlen_dim_idx] = b * 3 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = 3 * d; + strideA[seqlen_transpose_dim_idx] = b * 3 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = 2 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * 2 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = 2 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = b * 2 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = 2 * h * d; + strideA[head_dim_idx] = 2 * d; + strideA[seqlen_dim_idx] = b * 2 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + 
(matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = 2 * h * d; + strideA[head_dim_idx] = 2 * d; + strideA[seqlen_transpose_dim_idx] = b * 2 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = b * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = s_q * 3 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = 3 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = s_q * 3 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = 3 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = s_q * 3 * h * d; + strideA[head_dim_idx] = 3 * d; + strideA[seqlen_dim_idx] = 3 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = s_q * 3 * h * d; + strideA[head_dim_idx] = 3 * d; + strideA[seqlen_transpose_dim_idx] = 3 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = s_kv * 2 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = 2 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = s_kv * 2 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = 2 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == 
NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = s_kv * 2 * h * d; + strideA[head_dim_idx] = 2 * d; + strideA[seqlen_dim_idx] = 2 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = s_kv * 2 * h * d; + strideA[head_dim_idx] = 2 * d; + strideA[seqlen_transpose_dim_idx] = 2 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = h * d; + strideA[hidden_transpose_dim_idx] = 1; + } + break; + } - if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { - strideA[seqlen_kv_dim_idx] = 1; - strideA[seqlen_q_dim_idx] = s_kv; - strideA[head_dim_idx] = s_q * s_kv; - strideA[batch_dim_idx] = h * s_q * s_kv; - } + if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { + strideA[seqlen_kv_dim_idx] = 1; + strideA[seqlen_q_dim_idx] = s_kv; + strideA[head_dim_idx] = s_q * s_kv; + strideA[batch_dim_idx] = h * s_q * s_kv; + } } bool allowAllConfig(cudnnBackendDescriptor_t engine_config) { @@ -257,12 +254,11 @@ bool allowAllConfig(cudnnBackendDescriptor_t engine_config) { return false; } -cudnn_frontend::Tensor tensor_create( - cudnnDataType_t type, int64_t id, - int64_t const * dim, int64_t const * stride, - bool is_virtual, bool is_value) { +cudnn_frontend::Tensor tensor_create(cudnnDataType_t type, int64_t id, int64_t const *dim, + int64_t const *stride, bool is_virtual, bool is_value) { int nbDims = 4; - auto tensor_created = cudnn_frontend::TensorBuilder() + auto tensor_created = + cudnn_frontend::TensorBuilder() .setDim(nbDims, dim) .setStride(nbDims, stride) .setId(id) @@ -275,12 +271,11 @@ cudnn_frontend::Tensor tensor_create( } cudnn_frontend::Tensor tensor_create_with_offset( - cudnnDataType_t type, int64_t id, - int64_t const * dim, int64_t const * stride, - bool is_virtual, bool is_value, - std::shared_ptr raggedOffset) { + cudnnDataType_t type, int64_t id, int64_t const *dim, int64_t const *stride, bool is_virtual, + bool is_value, std::shared_ptr raggedOffset) { int nbDims = 4; - auto tensor_created = cudnn_frontend::TensorBuilder() + auto tensor_created = + cudnn_frontend::TensorBuilder() .setDim(nbDims, 
dim) .setStride(nbDims, stride) .setId(id) @@ -293,62 +288,58 @@ cudnn_frontend::Tensor tensor_create_with_offset( return tensor_created; } -cudnn_frontend::PointWiseDesc pw_desc_create( - cudnnDataType_t type, cudnnPointwiseMode_t mode) { - auto pw_desc_created = cudnn_frontend::PointWiseDescBuilder() - .setMode(mode) - .setComputeType(type) - .build(); +cudnn_frontend::PointWiseDesc pw_desc_create(cudnnDataType_t type, cudnnPointwiseMode_t mode) { + auto pw_desc_created = + cudnn_frontend::PointWiseDescBuilder().setMode(mode).setComputeType(type).build(); return pw_desc_created; } -cudnn_frontend::Operation unary_pw_op_create( - cudnn_frontend::Tensor const &xDesc, - cudnn_frontend::Tensor const &yDesc, - cudnn_frontend::PointWiseDesc const &pwDesc) { - auto pw_op_created = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(xDesc) - .setyDesc(yDesc) - .setpwDesc(pwDesc) - .build(); +cudnn_frontend::Operation unary_pw_op_create(cudnn_frontend::Tensor const &xDesc, + cudnn_frontend::Tensor const &yDesc, + cudnn_frontend::PointWiseDesc const &pwDesc) { + auto pw_op_created = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(xDesc) + .setyDesc(yDesc) + .setpwDesc(pwDesc) + .build(); return pw_op_created; } -cudnn_frontend::Operation binary_pw_op_create( - cudnn_frontend::Tensor const &xDesc, - cudnn_frontend::Tensor const &bDesc, - cudnn_frontend::Tensor const &yDesc, - cudnn_frontend::PointWiseDesc const &pwDesc) { - auto pw_op_created = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(xDesc) - .setbDesc(bDesc) - .setyDesc(yDesc) - .setpwDesc(pwDesc) - .build(); +cudnn_frontend::Operation binary_pw_op_create(cudnn_frontend::Tensor const &xDesc, + cudnn_frontend::Tensor const &bDesc, + cudnn_frontend::Tensor const &yDesc, + cudnn_frontend::PointWiseDesc const &pwDesc) { + auto pw_op_created = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(xDesc) + .setbDesc(bDesc) + .setyDesc(yDesc) + .setpwDesc(pwDesc) + .build(); return pw_op_created; } -cudnn_frontend::Operation ternary_pw_op_create( - cudnn_frontend::Tensor const &xDesc, cudnn_frontend::Tensor const &bDesc, - cudnn_frontend::Tensor const &tDesc, cudnn_frontend::Tensor const &yDesc, - cudnn_frontend::PointWiseDesc const &pwDesc) { - auto pw_op_created = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(xDesc) - .setbDesc(bDesc) - .settDesc(tDesc) - .setyDesc(yDesc) - .setpwDesc(pwDesc) - .build(); +cudnn_frontend::Operation ternary_pw_op_create(cudnn_frontend::Tensor const &xDesc, + cudnn_frontend::Tensor const &bDesc, + cudnn_frontend::Tensor const &tDesc, + cudnn_frontend::Tensor const &yDesc, + cudnn_frontend::PointWiseDesc const &pwDesc) { + auto pw_op_created = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(xDesc) + .setbDesc(bDesc) + .settDesc(tDesc) + .setyDesc(yDesc) + .setpwDesc(pwDesc) + .build(); return pw_op_created; } // convert cu_seqlens_q to qkv/o_ragged_offset and actual_seqlens_q -__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, - int32_t *cu_seqlens_q, int32_t *actual_seqlens_q, - int32_t *qkv_ragged_offset, int32_t *o_ragged_offset) { +__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, int32_t *cu_seqlens_q, + int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset, + int32_t *o_ragged_offset) { size_t tid = blockIdx.x * 
blockDim.x + threadIdx.x; if (tid < b) { actual_seqlens_q[tid] = cu_seqlens_q[tid + 1] - cu_seqlens_q[tid]; @@ -360,10 +351,9 @@ __global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, } // convert cu_seqlens to actual_seqlens -__global__ void cu_seqlens_to_actual_seqlens(size_t b, - int32_t const * const q_cu_seqlens, - int32_t const * const kv_cu_seqlens, - int32_t *q_seqlens, int32_t *kv_seqlens) { +__global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu_seqlens, + int32_t const *const kv_cu_seqlens, int32_t *q_seqlens, + int32_t *kv_seqlens) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid < b) { q_seqlens[tid] = q_cu_seqlens[tid + 1] - q_cu_seqlens[tid]; diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 11da5cf56c6c0a82106903e60cabe254c93481b2..b139280ec479a91dfcd90cba45c8853185a441d8 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -7,9 +7,6 @@ #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_ #define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_ -#include "transformer_engine/fused_attn.h" -#include "transformer_engine/transformer_engine.h" - #include #include #include @@ -17,56 +14,52 @@ #include #include +#include "transformer_engine/fused_attn.h" +#include "transformer_engine/transformer_engine.h" + namespace transformer_engine { namespace fused_attn { using namespace transformer_engine; enum NVTE_QKV_Matrix { - NVTE_Q_Matrix = 0, // queries - NVTE_K_Matrix = 1, // keys - NVTE_K_Matrix_Transpose = 2, // keys transposed - NVTE_V_Matrix = 3, // values - NVTE_V_Matrix_Transpose = 4, // value matrix transposed - NVTE_S_Matrix = 5, // output of GEMM1 - NVTE_O_Matrix = 6, // final output + NVTE_Q_Matrix = 0, // queries + NVTE_K_Matrix = 1, // keys + NVTE_K_Matrix_Transpose = 2, // keys transposed + NVTE_V_Matrix = 3, // values + NVTE_V_Matrix_Transpose = 4, // value matrix transposed + NVTE_S_Matrix = 5, // output of GEMM1 + NVTE_O_Matrix = 6, // final output }; -void generateMatrixStrides( - int64_t b, int64_t h, - int64_t s_q, int64_t s_kv, - int64_t d, int64_t* strideA, - NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix); +void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix); bool allowAllConfig(cudnnBackendDescriptor_t engine_config); -cudnn_frontend::Tensor tensor_create(cudnnDataType_t type, int64_t id, - int64_t const *dim, - int64_t const *stride, - bool is_virtual, bool is_value); +cudnn_frontend::Tensor tensor_create(cudnnDataType_t type, int64_t id, int64_t const *dim, + int64_t const *stride, bool is_virtual, bool is_value); cudnn_frontend::Tensor tensor_create_with_offset( - cudnnDataType_t type, int64_t id, - int64_t const * dim, int64_t const * stride, - bool is_virtual, bool is_value, - std::shared_ptr raggedOffset); + cudnnDataType_t type, int64_t id, int64_t const *dim, int64_t const *stride, bool is_virtual, + bool is_value, std::shared_ptr raggedOffset); -cudnn_frontend::PointWiseDesc pw_desc_create(cudnnDataType_t type, - cudnnPointwiseMode_t mode); +cudnn_frontend::PointWiseDesc pw_desc_create(cudnnDataType_t type, cudnnPointwiseMode_t mode); -cudnn_frontend::Operation unary_pw_op_create( - cudnn_frontend::Tensor const &xDesc, cudnn_frontend::Tensor const &yDesc, - cudnn_frontend::PointWiseDesc const &pwDesc); +cudnn_frontend::Operation unary_pw_op_create(cudnn_frontend::Tensor const &xDesc, + 
cudnn_frontend::Tensor const &yDesc, + cudnn_frontend::PointWiseDesc const &pwDesc); -cudnn_frontend::Operation binary_pw_op_create( - cudnn_frontend::Tensor const &xDesc, cudnn_frontend::Tensor const &bDesc, - cudnn_frontend::Tensor const &yDesc, - cudnn_frontend::PointWiseDesc const &pwDesc); +cudnn_frontend::Operation binary_pw_op_create(cudnn_frontend::Tensor const &xDesc, + cudnn_frontend::Tensor const &bDesc, + cudnn_frontend::Tensor const &yDesc, + cudnn_frontend::PointWiseDesc const &pwDesc); -cudnn_frontend::Operation ternary_pw_op_create( - cudnn_frontend::Tensor const &xDesc, cudnn_frontend::Tensor const &bDesc, - cudnn_frontend::Tensor const &tDesc, cudnn_frontend::Tensor const &yDesc, - cudnn_frontend::PointWiseDesc const &pwDesc); +cudnn_frontend::Operation ternary_pw_op_create(cudnn_frontend::Tensor const &xDesc, + cudnn_frontend::Tensor const &bDesc, + cudnn_frontend::Tensor const &tDesc, + cudnn_frontend::Tensor const &yDesc, + cudnn_frontend::PointWiseDesc const &pwDesc); struct FADescriptor { std::int64_t b; @@ -84,15 +77,11 @@ struct FADescriptor { bool use_workspace_opt; bool operator<(const FADescriptor &rhs) const { - return std::tie(b, h, s_q, s_kv, d, - attnScale, isTraining, dropoutProbability, - layout, mask_type, bias_type, tensor_type, use_workspace_opt) - < std::tie( - rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d, - rhs.attnScale, rhs.isTraining, - rhs.dropoutProbability, rhs.layout, - rhs.mask_type, rhs.bias_type, - rhs.tensor_type, rhs.use_workspace_opt); + return std::tie(b, h, s_q, s_kv, d, attnScale, isTraining, dropoutProbability, layout, + mask_type, bias_type, tensor_type, use_workspace_opt) < + std::tie(rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d, rhs.attnScale, rhs.isTraining, + rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.bias_type, + rhs.tensor_type, rhs.use_workspace_opt); } }; @@ -115,27 +104,22 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t bwd_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { - return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, - attnScale, isTraining, dropoutProbability, - layout, mask_type, bias_type, fwd_tensor_type, bwd_tensor_type) - < std::tie( - rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, - rhs.bias_b, rhs.bias_h, - rhs.attnScale, rhs.isTraining, - rhs.dropoutProbability, rhs.layout, - rhs.mask_type, rhs.bias_type, - rhs.fwd_tensor_type, rhs.bwd_tensor_type); + return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, attnScale, isTraining, + dropoutProbability, layout, mask_type, bias_type, fwd_tensor_type, + bwd_tensor_type) < + std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, rhs.bias_b, rhs.bias_h, + rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, + rhs.mask_type, rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); } }; -__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, - int32_t *cu_seqlens_q, int32_t *actual_seqlens_q, - int32_t *qkv_ragged_offset, int32_t *o_ragged_offset); +__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, int32_t *cu_seqlens_q, + int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset, + int32_t *o_ragged_offset); -__global__ void cu_seqlens_to_actual_seqlens(size_t b, - int32_t const * const q_cu_seqlens, - int32_t const * const kv_cu_seqlens, - int32_t *q_seqlens, int32_t *kv_seqlens); +__global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu_seqlens, + int32_t const *const kv_cu_seqlens, int32_t *q_seqlens, + int32_t *kv_seqlens); } // namespace fused_attn @@ -144,22 +128,21 
@@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) class cudnnExecutionPlanManager { public: - static cudnnExecutionPlanManager &Instance() { - static thread_local cudnnExecutionPlanManager instance; - return instance; - } + static cudnnExecutionPlanManager &Instance() { + static thread_local cudnnExecutionPlanManager instance; + return instance; + } - cudnnHandle_t GetCudnnHandle() { - static thread_local std::once_flag flag; - std::call_once(flag, [&] { cudnnCreate(&handle_); }); - return handle_; - } + cudnnHandle_t GetCudnnHandle() { + static thread_local std::once_flag flag; + std::call_once(flag, [&] { cudnnCreate(&handle_); }); + return handle_; + } - ~cudnnExecutionPlanManager() { - } + ~cudnnExecutionPlanManager() {} private: - cudnnHandle_t handle_ = nullptr; + cudnnHandle_t handle_ = nullptr; }; } // namespace transformer_engine diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index 14f76175dc9c01f9cd4f5ad6f4860c0ff2135011..e7cf940a574b48f2c3dab15e610ff51bc3f94a6d 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -14,11 +14,11 @@ namespace transformer_engine { template -__device__ void fused_rope_block_forward( - const scalar_t *src, const float *freqs, scalar_t *dst, - const int offset_block, const int offset_block_dst, const int h, - const int d, const int d2, const int stride_h, const int stride_d, - const int o_stride_h, const int o_stride_d) { +__device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, const int stride_h, + const int stride_d, const int o_stride_h, + const int o_stride_d) { int s_id = blockIdx.x; #pragma unroll for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { @@ -30,10 +30,9 @@ __device__ void fused_rope_block_forward( int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; float v_src_rotate = (d_id + d2 / 2 < d2) - ? -static_cast(src[offset_src + (d2 / 2) * stride_d]) - : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); - dst[offset_dst] = - v_src * v_cos + v_src_rotate * v_sin; + ? 
-static_cast(src[offset_src + (d2 / 2) * stride_d]) + : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; } } @@ -45,34 +44,31 @@ __device__ void fused_rope_block_forward( int offset_head_dst = offset_block_dst + h_id * o_stride_h; #pragma unroll for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { - dst[offset_head_dst + d_id * o_stride_d] = - src[offset_head + d_id * stride_d]; + dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; } } } } template -__device__ void fused_rope_block_backward( - const scalar_t *src, const float *freqs, scalar_t *dst, - const int offset_block, const int offset_block_dst, const int h, - const int d, const int d2, const int stride_h, const int stride_d, - const int o_stride_h, const int o_stride_d) { +__device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, + const int stride_h, const int stride_d, + const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x; #pragma unroll for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { float v_cos = cosf(freqs[s_id * d2 + d_id]); - float v_sin = (d_id + d2 / 2 < d2) - ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) - : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); + float v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) + : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); #pragma unroll for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; - float v_src_rotate = (d_id + d2 / 2 < d2) - ? src[offset_src + (d2 / 2) * stride_d] - : src[offset_src + (d2 / 2 - d2) * stride_d]; + float v_src_rotate = (d_id + d2 / 2 < d2) ? 
src[offset_src + (d2 / 2) * stride_d] + : src[offset_src + (d2 / 2 - d2) * stride_d]; dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; } } @@ -85,279 +81,251 @@ __device__ void fused_rope_block_backward( int offset_head_dst = offset_block_dst + h_id * o_stride_h; #pragma unroll for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { - dst[offset_head_dst + d_id * o_stride_d] = - src[offset_head + d_id * stride_d]; + dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; } } } } template -__global__ void fused_rope_forward_kernel( - const scalar_t *src, const float *freqs, scalar_t *dst, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d) { +__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, + const int h, const int d, const int d2, + const int stride_s, const int stride_b, + const int stride_h, const int stride_d, + const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; int offset_block = s_id * stride_s + b_id * stride_b; int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; - fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, - d, d2, stride_h, stride_d, o_stride_h, o_stride_d); + fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, + stride_d, o_stride_h, o_stride_d); } template -__global__ void fused_rope_backward_kernel( - const scalar_t *src, const float *freqs, scalar_t *dst, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d) { +__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, + const int h, const int d, const int d2, + const int stride_s, const int stride_b, + const int stride_h, const int stride_d, + const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; int offset_block = s_id * stride_s + b_id * stride_b; int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; - fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, - d, d2, stride_h, stride_d, o_stride_h, o_stride_d); + fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, + stride_d, o_stride_h, o_stride_d); } template -__global__ void fused_rope_thd_forward_kernel( - const scalar_t *src, const int *cu_seqlens, const float *freqs, - scalar_t *dst, const int h, const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d) { +__global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens, + const float *freqs, scalar_t *dst, const int h, + const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, + const int o_stride_t, const int o_stride_h, + const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; int t_id = s_id + cu_seqlens[b_id]; if (t_id >= cu_seqlens[b_id + 1]) return; int offset_block = t_id * stride_t; int offset_block_dst = t_id * o_stride_t; - fused_rope_block_forward(src, freqs, dst, offset_block, 
offset_block_dst, h, - d, d2, stride_h, stride_d, o_stride_h, o_stride_d); + fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, + stride_d, o_stride_h, o_stride_d); } template -__global__ void fused_rope_thd_backward_kernel( - const scalar_t *src, const int *cu_seqlens, const float *freqs, - scalar_t *dst, const int h, const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d) { +__global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens, + const float *freqs, scalar_t *dst, const int h, + const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, + const int o_stride_t, const int o_stride_h, + const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; int t_id = s_id + cu_seqlens[b_id]; if (t_id >= cu_seqlens[b_id + 1]) return; int offset_block = t_id * stride_t; int offset_block_dst = t_id * o_stride_t; - fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, - d, d2, stride_h, stride_d, o_stride_h, o_stride_d); + fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, + stride_d, o_stride_h, o_stride_d); } template -void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, - scalar_t *output, const int s, const int b, - const int h, const int d, const int d2, - const int stride_s, const int stride_b, - const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { +void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scalar_t *output, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_forward_kernel<<>>( - input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, - o_stride_s, o_stride_b, o_stride_h, o_stride_d); + input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template -void fused_rope_backward_launcher(const scalar_t *output_grads, - const float *freqs, scalar_t *input_grads, - const int s, const int b, const int h, - const int d, const int d2, const int stride_s, - const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { +void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs, + scalar_t *input_grads, const int s, const int b, const int h, + const int d, const int d2, const int stride_s, const int stride_b, + const int stride_h, const int stride_d, const int o_stride_s, + const int o_stride_b, const int o_stride_h, const int o_stride_d, + cudaStream_t stream) { int warps_per_block = h < 16 ? 
4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_backward_kernel<<>>( - output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, - stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d); + output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d, + o_stride_s, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template -void fused_rope_thd_forward_launcher( - const scalar_t *input, const int *cu_seqlens, const float *freqs, - scalar_t *output, const int max_s, const int b, const int h, const int d, - const int d2, const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { +void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens, + const float *freqs, scalar_t *output, const int max_s, + const int b, const int h, const int d, const int d2, + const int stride_t, const int stride_h, const int stride_d, + const int o_stride_t, const int o_stride_h, + const int o_stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(max_s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); - fused_rope_thd_forward_kernel<<>>( - input, cu_seqlens, freqs, output, h, d, d2, stride_t, stride_h, stride_d, - o_stride_t, o_stride_h, o_stride_d); + fused_rope_thd_forward_kernel<<>>(input, cu_seqlens, freqs, output, h, + d, d2, stride_t, stride_h, stride_d, + o_stride_t, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template -void fused_rope_thd_backward_launcher( - const scalar_t *output_grads, const int *cu_seqlens, const float *freqs, - scalar_t *input_grads, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { +void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, + const float *freqs, scalar_t *input_grads, const int max_s, + const int b, const int h, const int d, const int d2, + const int stride_t, const int stride_h, const int stride_d, + const int o_stride_t, const int o_stride_h, + const int o_stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 
4 : 8; dim3 blocks(max_s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_thd_backward_kernel<<>>( - output_grads, cu_seqlens, freqs, input_grads, h, d, d2, stride_t, - stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d); + output_grads, cu_seqlens, freqs, input_grads, h, d, d2, stride_t, stride_h, stride_d, + o_stride_t, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } -void fused_rope_forward(const Tensor &input, const Tensor &freqs, - Tensor *output, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, - const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, +void fused_rope_forward(const Tensor &input, const Tensor &freqs, Tensor *output, const int s, + const int b, const int h, const int d, const int d2, const int stride_s, + const int stride_b, const int stride_h, const int stride_d, + const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, scalar_t, - fused_rope_forward_launcher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(freqs.data.dptr), - reinterpret_cast(output->data.dptr), s, b, h, d, d2, - stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, stream);); + fused_rope_forward_launcher(reinterpret_cast(input.data.dptr), + reinterpret_cast(freqs.data.dptr), + reinterpret_cast(output->data.dptr), s, b, h, d, d2, + stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, stream);); } -void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, - Tensor *input_grads, const int s, const int b, - const int h, const int d, const int d2, - const int stride_s, const int stride_b, - const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { +void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor *input_grads, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, - fused_rope_backward_launcher( - reinterpret_cast(output_grads.data.dptr), - reinterpret_cast(freqs.data.dptr), - reinterpret_cast(input_grads->data.dptr), s, b, h, d, d2, - stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, stream);); + fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(freqs.data.dptr), + reinterpret_cast(input_grads->data.dptr), s, b, h, d, + d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d, stream);); } -void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, - const Tensor &freqs, Tensor *output, - const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, +void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, + Tensor *output, const int max_s, const int b, const int h, const int d, + const int d2, const int stride_t, const int stride_h, + const int stride_d, const 
int o_stride_t, const int o_stride_h, const int o_stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, scalar_t, - fused_rope_thd_forward_launcher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(cu_seqlens.data.dptr), - reinterpret_cast(freqs.data.dptr), - reinterpret_cast(output->data.dptr), max_s, b, h, d, d2, - stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, - stream);); + fused_rope_thd_forward_launcher(reinterpret_cast(input.data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), + reinterpret_cast(freqs.data.dptr), + reinterpret_cast(output->data.dptr), max_s, b, h, + d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, + o_stride_d, stream);); } -void fused_rope_thd_backward(const Tensor &output_grads, - const Tensor &cu_seqlens, const Tensor &freqs, - Tensor *input_grads, const int max_s, const int b, - const int h, const int d, const int d2, - const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { +void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens, + const Tensor &freqs, Tensor *input_grads, const int max_s, const int b, + const int h, const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, - fused_rope_thd_backward_launcher( - reinterpret_cast(output_grads.data.dptr), - reinterpret_cast(cu_seqlens.data.dptr), - reinterpret_cast(freqs.data.dptr), - reinterpret_cast(input_grads->data.dptr), max_s, b, h, d, - d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, - stream);); + fused_rope_thd_backward_launcher(reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), + reinterpret_cast(freqs.data.dptr), + reinterpret_cast(input_grads->data.dptr), max_s, + b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, + o_stride_h, o_stride_d, stream);); } } // end namespace transformer_engine -void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, - NVTETensor output, const int s, const int b, - const int h, const int d, const int d2, - const int stride_s, const int stride_b, - const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { +void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_forward); using namespace transformer_engine; fused_rope_forward(*reinterpret_cast(input), - *reinterpret_cast(freqs), - reinterpret_cast(output), s, b, h, d, d2, - stride_s, stride_b, stride_h, stride_d, o_stride_s, - o_stride_b, o_stride_h, o_stride_d, stream); + *reinterpret_cast(freqs), reinterpret_cast(output), + s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, stream); } -void nvte_fused_rope_backward(const NVTETensor output_grads, - const NVTETensor freqs, NVTETensor input_grads, - const int s, const int b, const int h, - const int d, const int d2, const int 
stride_s, - const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { +void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, + NVTETensor input_grads, const int s, const int b, const int h, + const int d, const int d2, const int stride_s, const int stride_b, + const int stride_h, const int stride_d, const int o_stride_s, + const int o_stride_b, const int o_stride_h, const int o_stride_d, + cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_backward); using namespace transformer_engine; fused_rope_backward(*reinterpret_cast(output_grads), *reinterpret_cast(freqs), - reinterpret_cast(input_grads), s, b, h, d, d2, - stride_s, stride_b, stride_h, stride_d, o_stride_s, - o_stride_b, o_stride_h, o_stride_d, stream); + reinterpret_cast(input_grads), s, b, h, d, d2, stride_s, stride_b, + stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream); } -void nvte_fused_rope_thd_forward(const NVTETensor input, - const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, - const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { +void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor output, const int max_s, + const int b, const int h, const int d, const int d2, + const int stride_t, const int stride_h, const int stride_d, + const int o_stride_t, const int o_stride_h, const int o_stride_d, + cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_thd_forward); using namespace transformer_engine; - fused_rope_thd_forward(*reinterpret_cast(input), - *reinterpret_cast(cu_seqlens), - *reinterpret_cast(freqs), - reinterpret_cast(output), max_s, b, h, d, d2, - stride_t, stride_h, stride_d, o_stride_t, o_stride_h, - o_stride_d, stream); + fused_rope_thd_forward( + *reinterpret_cast(input), *reinterpret_cast(cu_seqlens), + *reinterpret_cast(freqs), reinterpret_cast(output), max_s, b, h, d, + d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); } -void nvte_fused_rope_thd_backward( - const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int max_s, - const int b, const int h, const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { +void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor input_grads, const int max_s, + const int b, const int h, const int d, const int d2, + const int stride_t, const int stride_h, const int stride_d, + const int o_stride_t, const int o_stride_h, const int o_stride_d, + cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_thd_backward); using namespace transformer_engine; fused_rope_thd_backward(*reinterpret_cast(output_grads), *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), - reinterpret_cast(input_grads), max_s, b, h, - d, d2, stride_t, stride_h, stride_d, o_stride_t, - o_stride_h, o_stride_d, stream); + reinterpret_cast(input_grads), max_s, b, h, d, d2, stride_t, + stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); } diff --git 
a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu index 648f9ab6b19729a1f0899ed831a5ea8aeca6ce37..841edcf0435942cc694809cd569c1e0243c9d749 100644 --- a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu @@ -5,57 +5,55 @@ ************************************************************************/ #include -#include - -#include -#include -#include -#include - #include #include #include #include - +#include #include + +#include +#include +#include +#include + #include "../common.h" -#include "../utils.cuh" #include "../util/logging.h" - +#include "../utils.cuh" namespace transformer_engine { -template +template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); -template<> +template <> __device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { - *dst = *src; + *dst = *src; } -template<> +template <> __device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { - *((uint64_t*) dst) = *((uint64_t*) src); // NOLINT(*) + *((uint64_t *)dst) = *((uint64_t *)src); // NOLINT(*) } -template<> +template <> __device__ __inline__ void copy_vector(fp16 *dst, const fp16 *src) { - *dst = *src; + *dst = *src; } -template<> +template <> __device__ __inline__ void copy_vector(fp16 *dst, const fp16 *src) { - *((uint64_t*) dst) = *((uint64_t*) src); // NOLINT(*) + *((uint64_t *)dst) = *((uint64_t *)src); // NOLINT(*) } -template<> +template <> __device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { - *dst = *src; + *dst = *src; } -template<> +template <> __device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { - *((uint32_t*) dst) = *((uint32_t*) src); // NOLINT(*) + *((uint32_t *)dst) = *((uint32_t *)src); // NOLINT(*) } template @@ -63,378 +61,322 @@ __device__ __inline__ void copy_zero_vector(Datatype *dst); template <> __device__ __inline__ void copy_zero_vector(bf16 *dst) { - *dst = 0.0f; + *dst = 0.0f; } template <> __device__ __inline__ void copy_zero_vector(bf16 *dst) { - *((float2*) dst) = make_float2(0.0f, 0.0f); // NOLINT(*) + *((float2 *)dst) = make_float2(0.0f, 0.0f); // NOLINT(*) } template <> __device__ __inline__ void copy_zero_vector(fp16 *dst) { - *dst = 0.0f; + *dst = 0.0f; } template <> __device__ __inline__ void copy_zero_vector(fp16 *dst) { - *((float2*) dst) = make_float2(0.0f, 0.0f); // NOLINT(*) + *((float2 *)dst) = make_float2(0.0f, 0.0f); // NOLINT(*) } - -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? 
b : a; } }; -template +template __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_ROWS; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_ROWS; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); } + } } /* * Extended softmax (from native aten pytorch) with the following additional features * 1) input scaling * 2) implicit causal masking - * + * * works for all cases: * k > q * k < q * k = q - * + * * where: * microbatches = batches * attn_heads * query_seq_len * rows = query_seq_len * cols = key_seq_len */ template -__global__ void scaled_aligned_causal_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - const int microbatches, - const int rows, - const int cols -) { - // 1) WARP_WIDTH must match the value of warp_size - // 2) WARP_ROWS must match the value of rows_per_warp - // of the dispatch_scaled_aligned_causal_masked_softmax_forward method. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_WIDTH = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two - : THREADS_PER_WARP; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_WIDTH; - constexpr int WARP_ROWS = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 
1 : 4; - - const int global_row_idx = (blockIdx.x * blockDim.y + threadIdx.y) * WARP_ROWS; - const int col = threadIdx.x * ELEMENTS_PER_LDG_STG; - - const size_t thread_offset = global_row_idx * cols + col; - - src += thread_offset; - dst += thread_offset; - - // load data from global memory into registers WITH scaling - acc_t elements[WARP_ROWS][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - - #pragma unroll - for (int w = 0; w < WARP_ROWS; ++w) { - const int microbatch = global_row_idx + w; - const int i = microbatch % rows; // local row index of attention matrix - const int masked_elements = i + cols - rows + 1; - - if (microbatch >= microbatches) { - break; - } - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - const int j = col + it * WARP_WIDTH; - const int itr_idx = w * cols + it * WARP_WIDTH; - - if (j < masked_elements) { - copy_vector(temp_data, src + itr_idx); - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (j + element < masked_elements) { - elements[w][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[w][it + element] = (acc_t)( -10'000 ); - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[w][it + element] = (acc_t)( -10'000 ); - } - } - } +__global__ void scaled_aligned_causal_masked_softmax_warp_forward(output_t *dst, const input_t *src, + const acc_t scale, + const int microbatches, + const int rows, const int cols) { + // 1) WARP_WIDTH must match the value of warp_size + // 2) WARP_ROWS must match the value of rows_per_warp + // of the dispatch_scaled_aligned_causal_masked_softmax_forward method. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_WIDTH = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_WIDTH; + constexpr int WARP_ROWS = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + const int global_row_idx = (blockIdx.x * blockDim.y + threadIdx.y) * WARP_ROWS; + const int col = threadIdx.x * ELEMENTS_PER_LDG_STG; + + const size_t thread_offset = global_row_idx * cols + col; + + src += thread_offset; + dst += thread_offset; + + // load data from global memory into registers WITH scaling + acc_t elements[WARP_ROWS][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + const int microbatch = global_row_idx + w; + const int i = microbatch % rows; // local row index of attention matrix + const int masked_elements = i + cols - rows + 1; + + if (microbatch >= microbatches) { + break; } - // compute max_value - acc_t max_value[WARP_ROWS]; - #pragma unroll - for (int w = 0; w < WARP_ROWS; ++w) { - max_value[w] = elements[w][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[w] = - (max_value[w] > elements[w][it]) ? 
max_value[w] : elements[w][it]; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + const int j = col + it * WARP_WIDTH; + const int itr_idx = w * cols + it * WARP_WIDTH; + + if (j < masked_elements) { + copy_vector(temp_data, src + itr_idx); +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (j + element < masked_elements) { + elements[w][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[w][it + element] = (acc_t)(-10'000); + } } - } - warp_reduce(max_value); - - acc_t sum[WARP_ROWS] { 0.0f }; - #pragma unroll - for (int w = 0; w < WARP_ROWS; ++w) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[w][it] = expf((elements[w][it] - max_value[w])); - sum[w] += elements[w][it]; + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[w][it + element] = (acc_t)(-10'000); } + } + } + } + + // compute max_value + acc_t max_value[WARP_ROWS]; +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + max_value[w] = elements[w][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[w] = (max_value[w] > elements[w][it]) ? max_value[w] : elements[w][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_ROWS]{0.0f}; +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[w][it] = expf((elements[w][it] - max_value[w])); + sum[w] += elements[w][it]; + } + } + warp_reduce(sum); + + output_t out[ELEMENTS_PER_LDG_STG]{0.0f}; +// store result +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + const int microbatch = global_row_idx + w; + const int i = microbatch % rows; + const int masked_elements = i + cols - rows + 1; + + // out of Attention matrix bounds (rows) + if (microbatch >= microbatches) { + break; } - warp_reduce(sum); - - output_t out[ELEMENTS_PER_LDG_STG] { 0.0f }; - // store result - #pragma unroll - for (int w = 0; w < WARP_ROWS; ++w) { - const int microbatch = global_row_idx + w; - const int i = microbatch % rows; - const int masked_elements = i + cols - rows + 1; - - // out of Attention matrix bounds (rows) - if (microbatch >= microbatches) { - break; - } - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - const int j = col + it * WARP_WIDTH; // index of the first column - const int itr_idx = w * cols + it * WARP_WIDTH; - - if (j < masked_elements) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (j + element < masked_elements) { - out[element] = elements[w][it + element] / sum[w]; - } else { - out[element] = (output_t)( 0.0f ); - } - } - copy_vector(dst + itr_idx, out); - } else if (j < cols) { - copy_zero_vector(dst + itr_idx); - } else { - break; - } +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + const int j = col + it * WARP_WIDTH; // index of the first column + const int itr_idx = w * cols + it * WARP_WIDTH; + + if (j < masked_elements) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (j + element < masked_elements) { + out[element] = elements[w][it + element] / sum[w]; + } else { + out[element] = (output_t)(0.0f); + } } + copy_vector(dst + itr_idx, out); + } else if (j < cols) { + copy_zero_vector(dst + itr_idx); + } else { + break; + } } + } } - template __global__ void scaled_aligned_causal_masked_softmax_warp_backward( - output_t 
*gradInput, - const input_t *grad, - const input_t *softmax_output, - const acc_t scale, - const int microbatches, - const int rows, - const int cols -) { - // 1) WARP_WIDTH must match the value of warp_size - // 2) WARP_ROWS must match the value of rows_per_warp - // of the dispatch_scaled_aligned_causal_masked_softmax_forward method. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_WIDTH = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two - : THREADS_PER_WARP; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_WIDTH; - constexpr int WARP_ROWS = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - const int global_row_idx = (blockIdx.x * blockDim.y + threadIdx.y) * WARP_ROWS; - const int col = threadIdx.x * ELEMENTS_PER_LDG_STG; - - const size_t thread_offset = global_row_idx * cols + col; - - grad += thread_offset; - softmax_output += thread_offset; - gradInput += thread_offset; - - // load data from global memory into registers - acc_t grad_reg[WARP_ROWS][WARP_ITERATIONS] { 0.0f }; - acc_t softmax_output_reg[WARP_ROWS][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - - #pragma unroll - for (int w = 0; w < WARP_ROWS; ++w) { - const int microbatch = global_row_idx + w; - const int i = microbatch % rows; // local row index of attention matrix - const int masked_elements = i + cols - rows + 1; - - if (microbatch >= microbatches) { - break; - } - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - const int j = col + it * WARP_WIDTH; // index of the first column - const int itr_idx = w * cols + it * WARP_WIDTH; - - if (j < masked_elements) { - copy_vector(temp_grad, grad + itr_idx); - copy_vector(temp_output, softmax_output + itr_idx); - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (j + element < masked_elements) { - softmax_output_reg[w][it + element] = (acc_t)temp_output[element]; - grad_reg[w][it + element] = - (acc_t)temp_grad[element] * softmax_output_reg[w][it + element]; - } - } - } - } + output_t *gradInput, const input_t *grad, const input_t *softmax_output, const acc_t scale, + const int microbatches, const int rows, const int cols) { + // 1) WARP_WIDTH must match the value of warp_size + // 2) WARP_ROWS must match the value of rows_per_warp + // of the dispatch_scaled_aligned_causal_masked_softmax_forward method. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_WIDTH = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_WIDTH; + constexpr int WARP_ROWS = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 
1 : 4; + + const int global_row_idx = (blockIdx.x * blockDim.y + threadIdx.y) * WARP_ROWS; + const int col = threadIdx.x * ELEMENTS_PER_LDG_STG; + + const size_t thread_offset = global_row_idx * cols + col; + + grad += thread_offset; + softmax_output += thread_offset; + gradInput += thread_offset; + + // load data from global memory into registers + acc_t grad_reg[WARP_ROWS][WARP_ITERATIONS]{0.0f}; + acc_t softmax_output_reg[WARP_ROWS][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + const int microbatch = global_row_idx + w; + const int i = microbatch % rows; // local row index of attention matrix + const int masked_elements = i + cols - rows + 1; + + if (microbatch >= microbatches) { + break; } - acc_t sum[WARP_ROWS]; - #pragma unroll - for (int w = 0; w < WARP_ROWS; ++w) { - sum[w] = grad_reg[w][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[w] += grad_reg[w][it]; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + const int j = col + it * WARP_WIDTH; // index of the first column + const int itr_idx = w * cols + it * WARP_WIDTH; + + if (j < masked_elements) { + copy_vector(temp_grad, grad + itr_idx); + copy_vector(temp_output, softmax_output + itr_idx); +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (j + element < masked_elements) { + softmax_output_reg[w][it + element] = (acc_t)temp_output[element]; + grad_reg[w][it + element] = + (acc_t)temp_grad[element] * softmax_output_reg[w][it + element]; + } } + } + } + } + + acc_t sum[WARP_ROWS]; +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + sum[w] = grad_reg[w][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[w] += grad_reg[w][it]; } + } - warp_reduce(sum); + warp_reduce(sum); - // store result - #pragma unroll - for (int w = 0; w < WARP_ROWS; ++w) { - const int microbatch = global_row_idx + w; - if (microbatch >= microbatches) { - break; - } +// store result +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + const int microbatch = global_row_idx + w; + if (microbatch >= microbatches) { + break; + } - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - const int j = col + it * WARP_WIDTH; // index of the first column - const int itr_idx = w * cols + it * WARP_WIDTH; - - if (j < cols) { - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[w][it + element] - - softmax_output_reg[w][it + element] * sum[w])); - } - copy_vector(gradInput + itr_idx, out); - } +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + const int j = col + it * WARP_WIDTH; // index of the first column + const int itr_idx = w * cols + it * WARP_WIDTH; + + if (j < cols) { + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[w][it + element] - + softmax_output_reg[w][it + element] * sum[w])); } + copy_vector(gradInput + itr_idx, out); + } } + } } -template +template void call_kernel_scaled_aligned_causal_masked_softmax_forward( - dim3 grid_size, - dim3 block_size, - const int shmem_size, - cudaStream_t stream, - output_t *dst, - const input_t *src, - const acc_t scale, - const int microbatches, - const int 
query_seq_len, - const int key_seq_len -) { - scaled_aligned_causal_masked_softmax_warp_forward - <<>>( - dst, src, scale, microbatches, query_seq_len, key_seq_len); + dim3 grid_size, dim3 block_size, const int shmem_size, cudaStream_t stream, output_t *dst, + const input_t *src, const acc_t scale, const int microbatches, const int query_seq_len, + const int key_seq_len) { + scaled_aligned_causal_masked_softmax_warp_forward + <<>>(dst, src, scale, microbatches, query_seq_len, + key_seq_len); } -template +template void call_kernel_scaled_aligned_causal_masked_softmax_backward( - dim3 grid_size, - dim3 block_size, - const int shmem_size, - cudaStream_t stream, - output_t *gradInput, - const input_t *grad, - const input_t *output, - const acc_t scale, - const int microbatches, - const int query_seq_len, - const int key_seq_len -) { - scaled_aligned_causal_masked_softmax_warp_backward - <<>>( - gradInput, grad, output, scale, microbatches, query_seq_len, key_seq_len); + dim3 grid_size, dim3 block_size, const int shmem_size, cudaStream_t stream, output_t *gradInput, + const input_t *grad, const input_t *output, const acc_t scale, const int microbatches, + const int query_seq_len, const int key_seq_len) { + scaled_aligned_causal_masked_softmax_warp_backward + <<>>(gradInput, grad, output, scale, microbatches, + query_seq_len, key_seq_len); } -template +template struct FunctionWrapper { - using ForwardType = std::function< - void( - dim3 grid_size, - dim3 block_size, - const int shmem_size, - cudaStream_t stream, - output_t *dst, - const input_t *src, - const acc_t scale, - const int microbatches, - const int query_seq_len, - const int key_seq_len - ) - >; - using BackwardType = std::function< - void( - dim3 grid_size, - dim3 block_size, - const int shmem_size, - cudaStream_t stream, - output_t *gradInput, - const input_t *grad, - const input_t *output, - const acc_t scale, - const int microbatches, - const int query_seq_len, - const int key_seq_len - ) - >; + using ForwardType = + std::function; + using BackwardType = std::function; }; - constexpr int MIN_SUPPORTED_POWER = 4; constexpr int MAX_SUPPORTED_POWER = 14; constexpr int MIN_POWER = MIN_SUPPORTED_POWER - 1; @@ -444,228 +386,183 @@ constexpr int MAX_POWER = MAX_SUPPORTED_POWER + 1; // i.e. "MAX_POWER" defined above. 
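The CompileTimeLoopForward/CompileTimeLoopBackward helpers that follow build a static dispatch table: they recurse from the largest supported power of two down to MIN_POWER, storing at index log2_elements a pointer to the launch helper instantiated for that power, and leave the MIN_POWER slot as nullptr; the dispatch functions further down then index the table with log2_ceil(key_seq_len) instead of switching over every case. A minimal standalone sketch of the pattern, using hypothetical names (LaunchFn, PopulateTable, launch_for_power) that are not part of this patch:

#include <array>

constexpr int kMinPower = 3;   // sentinel slot, left unpopulated
constexpr int kMaxPower = 15;  // one past the largest supported power (2^14)
using LaunchFn = void (*)(int key_seq_len);

template <int kPower>
void launch_for_power(int key_seq_len) {
  // Stand-in for the real helper, which launches the kernel compiled for log2_elements == kPower.
}

template <int kPower>
struct PopulateTable {
  static void populate(std::array<LaunchFn, kMaxPower> *arr) {
    PopulateTable<kPower - 1>::populate(arr);    // fill the smaller powers first
    (*arr)[kPower] = &launch_for_power<kPower>;  // then register this power's launcher
  }
};

template <>
struct PopulateTable<kMinPower> {
  static void populate(std::array<LaunchFn, kMaxPower> *arr) { (*arr)[kMinPower] = nullptr; }
};

// Dispatch then reduces to one table lookup, roughly:
//   static std::array<LaunchFn, kMaxPower> table;
//   PopulateTable<kMaxPower - 1>::populate(&table);
//   table[log2_ceil(key_seq_len)](key_seq_len);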
template struct CompileTimeLoopForward { - using ForwardFuncType = typename FunctionWrapper::ForwardType; - static void populate(std::array* arr) { - CompileTimeLoopForward::populate(arr); - (*arr)[log2_elements] = &call_kernel_scaled_aligned_causal_masked_softmax_forward< - output_t, input_t, acc_t, log2_elements>; - } + using ForwardFuncType = typename FunctionWrapper::ForwardType; + static void populate(std::array *arr) { + CompileTimeLoopForward::populate(arr); + (*arr)[log2_elements] = + &call_kernel_scaled_aligned_causal_masked_softmax_forward; + } }; template struct CompileTimeLoopForward { - using ForwardFuncType = typename FunctionWrapper::ForwardType; - static void populate(std::array* arr) { - (*arr)[MIN_POWER] = nullptr; - } + using ForwardFuncType = typename FunctionWrapper::ForwardType; + static void populate(std::array *arr) { (*arr)[MIN_POWER] = nullptr; } }; template struct CompileTimeLoopBackward { - using BackwardFuncType = typename FunctionWrapper::BackwardType; - static void populate(std::array* arr) { - CompileTimeLoopBackward::populate(arr); - (*arr)[log2_elements] = &call_kernel_scaled_aligned_causal_masked_softmax_backward< - output_t, input_t, acc_t, log2_elements>; - } + using BackwardFuncType = typename FunctionWrapper::BackwardType; + static void populate(std::array *arr) { + CompileTimeLoopBackward::populate(arr); + (*arr)[log2_elements] = + &call_kernel_scaled_aligned_causal_masked_softmax_backward; + } }; template struct CompileTimeLoopBackward { - using BackwardFuncType = typename FunctionWrapper::BackwardType; - static void populate(std::array* arr) { - (*arr)[MIN_POWER] = nullptr; - } + using BackwardFuncType = typename FunctionWrapper::BackwardType; + static void populate(std::array *arr) { + (*arr)[MIN_POWER] = nullptr; + } }; -template -void dispatch_scaled_aligned_causal_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - cudaStream_t stream -) { - NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); - - if (key_seq_len == 0) { - return; - } - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - - // This value must match the WARP_WIDTH constexpr - // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. - int warp_width = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two - : THREADS_PER_WARP; - - // This value must match the WARP_ROWS constexpr - // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. - int microbatches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = threads_per_block / warp_width; - int microbatches_per_block = warps_per_block * microbatches_per_warp; - int microbatches = batches * attn_heads * query_seq_len; - int blocks = DIVUP(microbatches, microbatches_per_block); - - dim3 block_size(warp_width, warps_per_block); - dim3 grid_size(blocks); - - // create an array of pointers to functions - using ForwardFuncType = typename FunctionWrapper::ForwardType; - static std::array forwardFunctionArray; - static bool is_initialized = false; - if (!is_initialized) { - CompileTimeLoopForward::populate( - &forwardFunctionArray); - is_initialized = true; - } - // Call the corresponding kernel - forwardFunctionArray[log2_elements](grid_size, block_size, 0, stream, dst, src, scale, - microbatches, query_seq_len, key_seq_len); +template +void dispatch_scaled_aligned_causal_masked_softmax_forward(output_t *dst, const input_t *src, + const input_t scale, int query_seq_len, + int key_seq_len, int batches, + int attn_heads, cudaStream_t stream) { + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); + + if (key_seq_len == 0) { + return; + } + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_WIDTH constexpr + // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. + int warp_width = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_ROWS constexpr + // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. + int microbatches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = threads_per_block / warp_width; + int microbatches_per_block = warps_per_block * microbatches_per_warp; + int microbatches = batches * attn_heads * query_seq_len; + int blocks = DIVUP(microbatches, microbatches_per_block); + + dim3 block_size(warp_width, warps_per_block); + dim3 grid_size(blocks); + + // create an array of pointers to functions + using ForwardFuncType = typename FunctionWrapper::ForwardType; + static std::array forwardFunctionArray; + static bool is_initialized = false; + if (!is_initialized) { + CompileTimeLoopForward::populate( + &forwardFunctionArray); + is_initialized = true; + } + // Call the corresponding kernel + forwardFunctionArray[log2_elements](grid_size, block_size, 0, stream, dst, src, scale, + microbatches, query_seq_len, key_seq_len); } -template +template void dispatch_scaled_aligned_causal_masked_softmax_backward( - output_t *grad_input, - const input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - cudaStream_t stream -) { - NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); - - if (key_seq_len == 0) { - return; - } - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - - // This value must match the WARP_WIDTH constexpr - // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. - int warp_width = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; - - // This value must match the WARP_ROWS constexpr - // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. 
- int microbatches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = threads_per_block / warp_width; - int microbatches_per_block = warps_per_block * microbatches_per_warp; - int microbatches = batches * attn_heads * query_seq_len; - int blocks = DIVUP(microbatches, microbatches_per_block); - - dim3 block_size(warp_width, warps_per_block); - dim3 grid_size(blocks); - - // create an array of pointers to functions - using BackwardFuncType = typename FunctionWrapper::BackwardType; - static std::array backwardFunctionArray; - static bool is_initialized = false; - if (!is_initialized) { - CompileTimeLoopBackward::populate( - &backwardFunctionArray); - is_initialized = true; - } - // Call the corresponding kernel - backwardFunctionArray[log2_elements](grid_size, block_size, 0, stream, grad_input, grad, - output, scale, microbatches, query_seq_len, key_seq_len); + output_t *grad_input, const input_t *grad, const input_t *output, const acc_t scale, + int query_seq_len, int key_seq_len, int batches, int attn_heads, cudaStream_t stream) { + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); + + if (key_seq_len == 0) { + return; + } + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_WIDTH constexpr + // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. + int warp_width = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_ROWS constexpr + // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. + int microbatches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = threads_per_block / warp_width; + int microbatches_per_block = warps_per_block * microbatches_per_warp; + int microbatches = batches * attn_heads * query_seq_len; + int blocks = DIVUP(microbatches, microbatches_per_block); + + dim3 block_size(warp_width, warps_per_block); + dim3 grid_size(blocks); + + // create an array of pointers to functions + using BackwardFuncType = typename FunctionWrapper::BackwardType; + static std::array backwardFunctionArray; + static bool is_initialized = false; + if (!is_initialized) { + CompileTimeLoopBackward::populate( + &backwardFunctionArray); + is_initialized = true; + } + // Call the corresponding kernel + backwardFunctionArray[log2_elements](grid_size, block_size, 0, stream, grad_input, grad, output, + scale, microbatches, query_seq_len, key_seq_len); } - -void scaled_aligned_causal_masked_softmax_forward( - const Tensor &input, - Tensor *softmax_results, - float scale_factor, - cudaStream_t stream) { - - const int batches = input.data.shape[0]; - const int attn_heads = input.data.shape[1]; - const int query_seq_len = input.data.shape[2]; - const int key_seq_len = input.data.shape[3]; - - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.data.dtype, softmax_type, - dispatch_scaled_aligned_causal_masked_softmax_forward( - reinterpret_cast(softmax_results->data.dptr), - reinterpret_cast(input.data.dptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - stream);); +void scaled_aligned_causal_masked_softmax_forward(const Tensor &input, Tensor *softmax_results, + float scale_factor, cudaStream_t stream) { + const int batches = input.data.shape[0]; + const int attn_heads = input.data.shape[1]; + const int query_seq_len = input.data.shape[2]; + const int key_seq_len = input.data.shape[3]; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + input.data.dtype, softmax_type, + dispatch_scaled_aligned_causal_masked_softmax_forward( + reinterpret_cast(softmax_results->data.dptr), + reinterpret_cast(input.data.dptr), scale_factor, query_seq_len, + key_seq_len, batches, attn_heads, stream);); } -void scaled_aligned_causal_masked_softmax_backward( - Tensor output_grads, - const Tensor incoming_grads, - const Tensor softmax_results, - float scale_factor, - cudaStream_t stream) { - - // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = output_grads.data.shape[0]; - const int attn_heads = output_grads.data.shape[1]; - const int query_seq_len = output_grads.data.shape[2]; - const int key_seq_len = output_grads.data.shape[3]; - - // Softmax Grad - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.data.dtype, softmax_type, - dispatch_scaled_aligned_causal_masked_softmax_backward( - reinterpret_cast(output_grads.data.dptr), - reinterpret_cast(incoming_grads.data.dptr), - reinterpret_cast(softmax_results.data.dptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - stream);); +void scaled_aligned_causal_masked_softmax_backward(Tensor output_grads, const Tensor incoming_grads, + const Tensor softmax_results, float scale_factor, + cudaStream_t stream) { + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.data.shape[0]; + const int attn_heads = output_grads.data.shape[1]; + const int query_seq_len = output_grads.data.shape[2]; + const int key_seq_len = output_grads.data.shape[3]; + + 
// Softmax Grad + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + output_grads.data.dtype, softmax_type, + dispatch_scaled_aligned_causal_masked_softmax_backward( + reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(incoming_grads.data.dptr), + reinterpret_cast(softmax_results.data.dptr), scale_factor, + query_seq_len, key_seq_len, batches, attn_heads, stream);); } } // end namespace transformer_engine - -void nvte_scaled_aligned_causal_masked_softmax_forward( - const NVTETensor input, - NVTETensor softmax_results, - float scale_factor, - cudaStream_t stream -) { - NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_forward); - using namespace transformer_engine; - scaled_aligned_causal_masked_softmax_forward( - *reinterpret_cast(input), - reinterpret_cast(softmax_results), - scale_factor, - stream); +void nvte_scaled_aligned_causal_masked_softmax_forward(const NVTETensor input, + NVTETensor softmax_results, + float scale_factor, cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_forward); + using namespace transformer_engine; + scaled_aligned_causal_masked_softmax_forward(*reinterpret_cast(input), + reinterpret_cast(softmax_results), + scale_factor, stream); } - -void nvte_scaled_aligned_causal_masked_softmax_backward( - const NVTETensor incoming_grads, - const NVTETensor softmax_results, - NVTETensor output_grads, - float scale_factor, - cudaStream_t stream -) { - NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_backward); - using namespace transformer_engine; - scaled_aligned_causal_masked_softmax_backward( - *reinterpret_cast(output_grads), - *reinterpret_cast(incoming_grads), - *reinterpret_cast(softmax_results), - scale_factor, - stream); +void nvte_scaled_aligned_causal_masked_softmax_backward(const NVTETensor incoming_grads, + const NVTETensor softmax_results, + NVTETensor output_grads, float scale_factor, + cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_backward); + using namespace transformer_engine; + scaled_aligned_causal_masked_softmax_backward( + *reinterpret_cast(output_grads), *reinterpret_cast(incoming_grads), + *reinterpret_cast(softmax_results), scale_factor, stream); } diff --git a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu index 7a7194878e3ab26439a2fd2b2cf28e2b81231857..08fd32af9cca2d5c820c4c5bcb893f2cb44b19b9 100644 --- a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu @@ -5,1169 +5,846 @@ ************************************************************************/ #include -#include - -#include -#include - #include #include #include #include - +#include #include + +#include +#include + #include "../common.h" -#include "../utils.cuh" #include "../util/logging.h" - +#include "../utils.cuh" namespace transformer_engine { - template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> -__device__ __inline__ void copy_vector(bf16 *dst, - const bf16 *src) { +__device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { *dst = *src; } - template <> -__device__ __inline__ void copy_vector(bf16 *dst, - const bf16 *src) { - *((float2*) dst) = *((float2*) src); // NOLINT(*) +__device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { + *((float2 *)dst) = *((float2 *)src); // NOLINT(*) } template <> -__device__ __inline__ void copy_vector(half *dst, - const half *src) { +__device__ 
__inline__ void copy_vector(half *dst, const half *src) { *dst = *src; } template <> -__device__ __inline__ void copy_vector(half *dst, - const half *src) { - *((float2*) dst) = *((float2*) src); // NOLINT(*) +__device__ __inline__ void copy_vector(half *dst, const half *src) { + *((float2 *)dst) = *((float2 *)src); // NOLINT(*) } template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *((half2*) dst) = *((half2*) src); // NOLINT(*) +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { + *((half2 *)dst) = *((half2 *)src); // NOLINT(*) } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; } }; template __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); } + } } - /* * Extended softmax (from native aten pytorch) with following additional features * 1) input scaling */ template -__global__ void scaled_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < THREADS_PER_WARP) ? - next_power_of_two : THREADS_PER_WARP; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - size_t first_batch = (blockDim.y * (blockIdx.x + gridDim.x * - (blockIdx.y + gridDim.y * blockIdx.z)) - + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. 
compute the index within the batch - int local_idx = threadIdx.x; - - size_t thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - src += thread_offset; - dst += thread_offset; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } - } else { +__global__ void scaled_softmax_warp_forward(output_t *dst, const input_t *src, const acc_t scale, + int micro_batch_size, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + size_t first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + size_t thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + src += thread_offset; + dst += thread_offset; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i * element_count + it * WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } + } else { #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); } + } } + } - // compute max_value - acc_t max_value[WARP_BATCH]; + // compute max_value + acc_t max_value[WARP_BATCH]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } - warp_reduce(max_value); + } + warp_reduce(max_value); - acc_t sum[WARP_BATCH] { 0.0f }; + acc_t sum[WARP_BATCH]{0.0f}; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { + for (int i = 0; i < WARP_BATCH; ++i) { #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; } - warp_reduce(sum); + } + warp_reduce(sum); - // store result - output_t out[ELEMENTS_PER_LDG_STG]; + // store result + output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; - } - copy_vector(dst + i * element_count - + it * WARP_SIZE, out); - } else { - break; - } + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } } + } } - /* * Extended softmax (from native aten pytorch) with following additional features * 1) input scaling * 2) Explicit masking */ template -__global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. 
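In scaled_masked_softmax_warp_forward, the user-supplied byte mask is applied while loading: positions whose mask byte equals 1 are pinned to -10000 before the row max and sum reductions, and a row that ends up fully masked (its max stays at -10000) is zeroed through scale_value instead of degenerating into a uniform distribution. A rough host-side sketch of the per-row arithmetic, with an illustrative name (row_softmax) that does not exist in the library:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// CPU reference for the masking rule: mask == 1 pins the logit to -10000,
// and a row whose max is still -10000 (everything masked) is forced to all zeros.
std::vector<float> row_softmax(const std::vector<float> &logits,
                               const std::vector<uint8_t> &mask, float scale) {
  std::vector<float> x(logits.size());
  for (size_t j = 0; j < x.size(); ++j)
    x[j] = (mask[j] != 1) ? logits[j] * scale : -10000.0f;

  float max_value = x[0];
  for (float v : x) max_value = std::max(max_value, v);
  const float scale_value = (max_value == -10000.0f) ? 0.0f : 1.0f;  // fully masked row

  float sum = 0.0f;
  for (float &v : x) { v = std::exp(v - max_value); sum += v; }
  for (float &v : x) v = v * scale_value / sum;  // masked-out rows come back as zeros
  return x;
}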
- constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < THREADS_PER_WARP) ? - next_power_of_two : THREADS_PER_WARP; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - size_t first_batch = (blockDim.y * (blockIdx.x + gridDim.x * - (blockIdx.y + gridDim.y * blockIdx.z)) - + threadIdx.y) * WARP_BATCH; - size_t pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) - * WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } +__global__ void scaled_masked_softmax_warp_forward(output_t *dst, const input_t *src, + const uint8_t *mask, const acc_t scale, + int micro_batch_size, int element_count, + int pad_batches) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + size_t first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + WARP_BATCH; + size_t pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. 
compute the index within the batch - int local_idx = threadIdx.x; - - size_t thread_offset_src_dst = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - size_t thread_offset_mask = pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - src += thread_offset_src_dst; - dst += thread_offset_src_dst; - mask += thread_offset_mask; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + size_t thread_offset_src_dst = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + size_t thread_offset_mask = pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + src += thread_offset_src_dst; + dst += thread_offset_src_dst; + mask += thread_offset_mask; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); + if (element_index < batch_element_count) { + int itr_idx = i * element_count + it * WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); } + } } + } - // compute max_value - acc_t max_value[WARP_BATCH]; + // compute max_value + acc_t max_value[WARP_BATCH]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } - warp_reduce(max_value); + } + warp_reduce(max_value); - // compute scale value to account for full mask - acc_t scale_value[WARP_BATCH]; + // compute scale value to account for full mask + acc_t scale_value[WARP_BATCH]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - scale_value[i] = (max_value[i] == -10000.0) ? 
0.0 : 1.0; - } + for (int i = 0; i < WARP_BATCH; ++i) { + scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0; + } - acc_t sum[WARP_BATCH] { 0.0f }; + acc_t sum[WARP_BATCH]{0.0f}; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { + for (int i = 0; i < WARP_BATCH; ++i) { #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; } - warp_reduce(sum); + } + warp_reduce(sum); - // store result - output_t out[ELEMENTS_PER_LDG_STG]; + // store result + output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] * scale_value[i] / sum[i]; - } - copy_vector(dst + i * element_count - + it * WARP_SIZE, out); - } else { - break; - } + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] * scale_value[i] / sum[i]; } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } } + } } template -__global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - const input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < THREADS_PER_WARP) ? - next_power_of_two : THREADS_PER_WARP; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - size_t first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. 
compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - size_t thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count - + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count - + it * WARP_SIZE); - -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } +__global__ void scaled_masked_softmax_warp_backward(output_t *gradInput, const input_t *grad, + const input_t *output, acc_t scale, + int micro_batch_size, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + size_t first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + size_t thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, + grad + i * element_count + it * WARP_SIZE); + copy_vector(temp_output, + output + i * element_count + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * - output_reg[i][it + element]; - } - } + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; } + } } + } - acc_t sum[WARP_BATCH]; + acc_t sum[WARP_BATCH]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; } - warp_reduce(sum); + } + warp_reduce(sum); - // store result + // store result #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count - + it * WARP_SIZE, out); - } + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); } + copy_vector(gradInput + i * element_count + it * WARP_SIZE, + out); + } } + } } - -template -void dispatch_scaled_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - cudaStream_t stream) { - NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr - // value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two - : THREADS_PER_WARP; - - // This value must match the WARP_BATCH constexpr - // value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - NVTE_CHECK(query_seq_len%batches_per_block == 0, "Unsupported shape."); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 1: // 2 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 2: // 4 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 3: // 8 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 4: // 16 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 5: // 32 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 6: // 64 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 7: // 128 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 8: // 256 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 9: // 512 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 10: // 1024 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 11: // 2048 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 12: // 4096 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 13: // 8192 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - case 14: // 16384 - scaled_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - key_seq_len); - break; - default: - break; - } +template +void dispatch_scaled_softmax_forward(output_t *dst, const input_t *src, const input_t scale, + int query_seq_len, int key_seq_len, int batches, + int attn_heads, cudaStream_t stream) { + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr + // value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_BATCH constexpr + // value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + NVTE_CHECK(query_seq_len % batches_per_block == 0, "Unsupported shape."); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 12: // 4096 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 13: // 8192 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 14: // 16384 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + default: + break; } + } } -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches, - cudaStream_t stream) { - NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr - // value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two - : THREADS_PER_WARP; - - // This value must match the WARP_BATCH constexpr - // value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - NVTE_CHECK(query_seq_len%batches_per_block == 0, "Unsupported shape."); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 12: // 4096 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 13: // 8192 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - case 14: // 16384 - scaled_masked_softmax_warp_forward - <<>>(dst, - src, - mask, - scale, - batch_count, - key_seq_len, - pad_batches); - break; - default: - break; - } +template +void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, const uint8_t *mask, + const input_t scale, int query_seq_len, int key_seq_len, + int batches, int attn_heads, int pad_batches, + cudaStream_t stream) { + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr + // value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_BATCH constexpr + // value computed inside softmax_warp_forward. 
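// Editorial aside (not part of this patch): the "must match" comments above capture an
// implicit contract between this host-side dispatch code and the kernels -- warp_size and
// batches_per_warp computed here have to reproduce the WARP_SIZE and WARP_BATCH constexprs
// inside the *_warp_forward/_backward kernels, or the grid/block geometry no longer covers
// every softmax row exactly once. Below is a minimal, self-contained sketch of that shared
// computation; the helper name softmax_launch_geometry and the assumption that
// THREADS_PER_WARP equals 32 are ours for illustration, not taken from the library.

struct SoftmaxLaunchGeometry {
  int warp_size;          // lanes cooperating on one softmax row (mirrors WARP_SIZE)
  int batches_per_warp;   // rows handled per warp (mirrors WARP_BATCH)
  int warps_per_block;    // a 128-thread block split into warps of warp_size lanes
  int batches_per_block;  // rows handled per block; the grid size is derived from this
};

inline SoftmaxLaunchGeometry softmax_launch_geometry(int key_seq_len) {
  constexpr int kThreadsPerWarp = 32;    // assumed value of THREADS_PER_WARP
  constexpr int kThreadsPerBlock = 128;  // matches threads_per_block in the dispatchers

  // log2_ceil(key_seq_len): smallest p such that (1 << p) >= key_seq_len.
  int log2_elements = 0;
  while ((1 << log2_elements) < key_seq_len) ++log2_elements;
  const int next_power_of_two = 1 << log2_elements;

  SoftmaxLaunchGeometry g;
  // Rows shorter than a full warp use only next_power_of_two lanes.
  g.warp_size = (next_power_of_two < kThreadsPerWarp) ? next_power_of_two : kThreadsPerWarp;
  // Rows of at most 128 elements are packed two per warp.
  g.batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
  g.warps_per_block = kThreadsPerBlock / g.warp_size;
  g.batches_per_block = g.warps_per_block * g.batches_per_warp;
  return g;
}

// Example: key_seq_len == 2048 gives warp_size 32, batches_per_warp 1, warps_per_block 4
// and batches_per_block 4 -- hence the NVTE_CHECK below that query_seq_len is a multiple
// of batches_per_block before the kernel is launched.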
+ int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + NVTE_CHECK(query_seq_len % batches_per_block == 0, "Unsupported shape."); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 12: // 4096 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 13: // 8192 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 14: // 16384 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + default: + break; } + } } -template -void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - const input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - cudaStream_t stream) { - NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr - // value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two - : THREADS_PER_WARP; - - // This value must match the WARP_BATCH constexpr - // value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count/batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 5: // 32 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 11: // 2048 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 12: // 4096 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 13: // 8192 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - case 14: // 16384 - scaled_masked_softmax_warp_backward - <<>>(grad_input, - grad, - output, - scale, - batch_count, - key_seq_len); - break; - default: - break; - } +template +void dispatch_scaled_masked_softmax_backward(output_t *grad_input, const input_t *grad, + const input_t *output, const acc_t scale, + int query_seq_len, int key_seq_len, int batches, + int attn_heads, cudaStream_t stream) { + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr + // value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_BATCH constexpr + // value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 12: // 4096 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 13: // 8192 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 14: // 16384 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + default: + break; } + } } - -void scaled_softmax_forward( - const Tensor &input, - Tensor *softmax_results, - float scale_factor, - cudaStream_t stream) { - - const int batches = input.data.shape[0]; - const int attn_heads = input.data.shape[1]; - const int query_seq_len = input.data.shape[2]; - const int key_seq_len = input.data.shape[3]; - - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.data.dtype, softmax_type, - dispatch_scaled_softmax_forward( - reinterpret_cast(softmax_results->data.dptr), - reinterpret_cast(input.data.dptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - stream);); +void scaled_softmax_forward(const Tensor &input, Tensor *softmax_results, float scale_factor, + cudaStream_t stream) { + const int batches = input.data.shape[0]; + const int attn_heads = input.data.shape[1]; + const int query_seq_len = input.data.shape[2]; + const int key_seq_len = input.data.shape[3]; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + input.data.dtype, softmax_type, + dispatch_scaled_softmax_forward( + reinterpret_cast(softmax_results->data.dptr), + reinterpret_cast(input.data.dptr), 
scale_factor, query_seq_len, + key_seq_len, batches, attn_heads, stream);); } -void scaled_softmax_backward( - Tensor output_grads, - const Tensor incoming_grads, - const Tensor softmax_results, - float scale_factor, - cudaStream_t stream) { - - // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = output_grads.data.shape[0]; - const int attn_heads = output_grads.data.shape[1]; - const int query_seq_len = output_grads.data.shape[2]; - const int key_seq_len = output_grads.data.shape[3]; - - // Softmax Grad - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.data.dtype, softmax_type, - dispatch_scaled_masked_softmax_backward( - reinterpret_cast(output_grads.data.dptr), - reinterpret_cast(incoming_grads.data.dptr), - reinterpret_cast(softmax_results.data.dptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - stream);); +void scaled_softmax_backward(Tensor output_grads, const Tensor incoming_grads, + const Tensor softmax_results, float scale_factor, + cudaStream_t stream) { + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.data.shape[0]; + const int attn_heads = output_grads.data.shape[1]; + const int query_seq_len = output_grads.data.shape[2]; + const int key_seq_len = output_grads.data.shape[3]; + + // Softmax Grad + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + output_grads.data.dtype, softmax_type, + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(incoming_grads.data.dptr), + reinterpret_cast(softmax_results.data.dptr), scale_factor, + query_seq_len, key_seq_len, batches, attn_heads, stream);); } - -void scaled_masked_softmax_forward( - const Tensor input, - const Tensor mask, - Tensor *softmax_results, - float scale_factor, - cudaStream_t stream) { - - const int batches = input.data.shape[0]; - const int pad_batches = mask.data.shape[0]; - const int attn_heads = input.data.shape[1]; - const int query_seq_len = input.data.shape[2]; - const int key_seq_len = input.data.shape[3]; - - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.data.dtype, softmax_type, - dispatch_scaled_masked_softmax_forward( - reinterpret_cast(softmax_results->data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(mask.data.dptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - pad_batches, - stream);); +void scaled_masked_softmax_forward(const Tensor input, const Tensor mask, Tensor *softmax_results, + float scale_factor, cudaStream_t stream) { + const int batches = input.data.shape[0]; + const int pad_batches = mask.data.shape[0]; + const int attn_heads = input.data.shape[1]; + const int query_seq_len = input.data.shape[2]; + const int key_seq_len = input.data.shape[3]; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + input.data.dtype, softmax_type, + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results->data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(mask.data.dptr), scale_factor, query_seq_len, + key_seq_len, batches, attn_heads, pad_batches, stream);); } - -void scaled_masked_softmax_backward( - Tensor output_grads, - const Tensor incoming_grads, - const Tensor softmax_results, - float scale_factor, - cudaStream_t stream -) { - // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = output_grads.data.shape[0]; - const int attn_heads = output_grads.data.shape[1]; - const int query_seq_len = 
output_grads.data.shape[2]; - const int key_seq_len = output_grads.data.shape[3]; - - // Softmax Grad - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.data.dtype, softmax_type, - dispatch_scaled_masked_softmax_backward( - reinterpret_cast(output_grads.data.dptr), - reinterpret_cast(incoming_grads.data.dptr), - reinterpret_cast(softmax_results.data.dptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - stream);); +void scaled_masked_softmax_backward(Tensor output_grads, const Tensor incoming_grads, + const Tensor softmax_results, float scale_factor, + cudaStream_t stream) { + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.data.shape[0]; + const int attn_heads = output_grads.data.shape[1]; + const int query_seq_len = output_grads.data.shape[2]; + const int key_seq_len = output_grads.data.shape[3]; + + // Softmax Grad + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + output_grads.data.dtype, softmax_type, + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(incoming_grads.data.dptr), + reinterpret_cast(softmax_results.data.dptr), scale_factor, + query_seq_len, key_seq_len, batches, attn_heads, stream);); } - } // end namespace transformer_engine - -void nvte_scaled_softmax_forward( - const NVTETensor input, - NVTETensor softmax_results, - float scale_factor, - cudaStream_t stream -) { +void nvte_scaled_softmax_forward(const NVTETensor input, NVTETensor softmax_results, + float scale_factor, cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_softmax_forward); using namespace transformer_engine; - scaled_softmax_forward( - *reinterpret_cast(input), - reinterpret_cast(softmax_results), - scale_factor, - stream); + scaled_softmax_forward(*reinterpret_cast(input), + reinterpret_cast(softmax_results), scale_factor, stream); } - -void nvte_scaled_softmax_backward( - const NVTETensor incoming_grads, - const NVTETensor softmax_results, - NVTETensor output_grads, - float scale_factor, - cudaStream_t stream -) { +void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results, + NVTETensor output_grads, float scale_factor, + cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_softmax_backward); using namespace transformer_engine; - scaled_softmax_backward( - *reinterpret_cast(output_grads), - *reinterpret_cast(incoming_grads), - *reinterpret_cast(softmax_results), - scale_factor, - stream); + scaled_softmax_backward(*reinterpret_cast(output_grads), + *reinterpret_cast(incoming_grads), + *reinterpret_cast(softmax_results), scale_factor, stream); } - -void nvte_scaled_masked_softmax_forward( - const NVTETensor input, - const NVTETensor mask, - NVTETensor softmax_results, - float scale_factor, - cudaStream_t stream -) { +void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor mask, + NVTETensor softmax_results, float scale_factor, + cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_masked_softmax_forward); using namespace transformer_engine; - scaled_masked_softmax_forward( - *reinterpret_cast(input), - *reinterpret_cast(mask), - reinterpret_cast(softmax_results), - scale_factor, - stream); + scaled_masked_softmax_forward(*reinterpret_cast(input), + *reinterpret_cast(mask), + reinterpret_cast(softmax_results), scale_factor, stream); } - -void nvte_scaled_masked_softmax_backward( - const NVTETensor incoming_grads, - const NVTETensor softmax_results, - NVTETensor output_grads, - float scale_factor, - 
cudaStream_t stream -) { +void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads, + const NVTETensor softmax_results, NVTETensor output_grads, + float scale_factor, cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_masked_softmax_backward); using namespace transformer_engine; scaled_masked_softmax_backward( - *reinterpret_cast(output_grads), - *reinterpret_cast(incoming_grads), - *reinterpret_cast(softmax_results), - scale_factor, - stream); + *reinterpret_cast(output_grads), *reinterpret_cast(incoming_grads), + *reinterpret_cast(softmax_results), scale_factor, stream); } diff --git a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu index 7b1ae14b1d944240fcab2cee9f27713874c2fabe..8571887ee64d8586b1f3bf538477d73ff515359a 100644 --- a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu @@ -5,21 +5,19 @@ ************************************************************************/ #include -#include - -#include -#include - #include #include #include #include - +#include #include + +#include +#include + #include "../common.h" -#include "../utils.cuh" #include "../util/logging.h" - +#include "../utils.cuh" namespace transformer_engine { @@ -27,39 +25,33 @@ template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> -__device__ __inline__ void copy_vector(bf16 *dst, - const bf16 *src) { +__device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { *dst = *src; } template <> -__device__ __inline__ void copy_vector(bf16 *dst, - const bf16 *src) { - *((float2*) dst) = *((float2*) src); // NOLINT(*) +__device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { + *((float2 *)dst) = *((float2 *)src); // NOLINT(*) } template <> -__device__ __inline__ void copy_vector(fp16 *dst, - const fp16 *src) { +__device__ __inline__ void copy_vector(fp16 *dst, const fp16 *src) { *dst = *src; } template <> -__device__ __inline__ void copy_vector(fp16 *dst, - const fp16 *src) { - *((float2*) dst) = *((float2*) src); // NOLINT(*) +__device__ __inline__ void copy_vector(fp16 *dst, const fp16 *src) { + *((float2 *)dst) = *((float2 *)src); // NOLINT(*) } template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *((half2*) dst) = *((half2*) src); // NOLINT(*) +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { + *((half2 *)dst) = *((half2 *)src); // NOLINT(*) } template @@ -72,52 +64,50 @@ __device__ __inline__ void copy_zero_vector(bf16 *dst) { template <> __device__ __inline__ void copy_zero_vector(bf16 *dst) { - *((float2*) dst) = make_float2(0.0f, 0.0f); // NOLINT(*) + *((float2 *)dst) = make_float2(0.0f, 0.0f); // NOLINT(*) } template <> -__device__ __inline__ void copy_zero_vector(fp16 *dst) { *dst = 0.0f; } +__device__ __inline__ void copy_zero_vector(fp16 *dst) { + *dst = 0.0f; +} template <> __device__ __inline__ void copy_zero_vector(fp16 *dst) { - *((float2*) dst) = make_float2(0.0f, 0.0f); // NOLINT(*) + *((float2 *)dst) = make_float2(0.0f, 0.0f); // NOLINT(*) } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ 
__forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; } }; template __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); } + } } /* @@ -126,675 +116,500 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) { * 2) Implicit time (diagonal masking) */ template -__global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < THREADS_PER_WARP) ? - next_power_of_two : THREADS_PER_WARP; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - size_t first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH - + blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - size_t thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - src += thread_offset; - dst += thread_offset; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; +__global__ void scaled_upper_triang_masked_softmax_warp_forward(output_t *dst, const input_t *src, + const acc_t scale, + int micro_batch_size, int stride, + int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 
2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + size_t first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + size_t thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + src += thread_offset; + dst += thread_offset; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_data, - src + i*element_count*stride - + it*WARP_SIZE); + if (element_index < batch_element_count) { + copy_vector( + temp_data, src + i * element_count * stride + it * WARP_SIZE); #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it+element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } else { + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); } + } } + } - // compute max_value - acc_t max_value[WARP_BATCH]; + // compute max_value + acc_t max_value[WARP_BATCH]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; } - warp_reduce(max_value); + } + warp_reduce(max_value); - acc_t sum[WARP_BATCH] { 0.0f }; + acc_t sum[WARP_BATCH]{0.0f}; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { + for (int i = 0; i < WARP_BATCH; ++i) { #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } } - warp_reduce(sum); + } + warp_reduce(sum); - // store result - output_t out[ELEMENTS_PER_LDG_STG]; + // store result + output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < local_seq) { + if (element_index < local_seq) { #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0.0f; - } - } - copy_vector(dst + i * element_count * stride - + it * WARP_SIZE, - out); - } else if (element_index < element_count) { - copy_zero_vector(dst + i * element_count * stride - + it * WARP_SIZE); - } else { - break; - } + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0.0f; + } } + copy_vector( + dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector(dst + i * element_count * stride + + it * WARP_SIZE); + } else { + break; + } } + } } template -__global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - const input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, - int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < THREADS_PER_WARP) ? - next_power_of_two : THREADS_PER_WARP; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - size_t first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH - + blockIdx.x; - int local_seq = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. 
compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - size_t thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; +__global__ void scaled_upper_triang_masked_softmax_warp_backward(output_t *gradInput, + const input_t *grad, + const input_t *output, acc_t scale, + int micro_batch_size, int stride, + int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + size_t first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + size_t thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
0 : local_seq; #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, - grad + i * element_count * stride - + it * WARP_SIZE); - copy_vector(temp_output, - output + i * element_count * stride - + it * WARP_SIZE); + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count * stride + it * WARP_SIZE); #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - } + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * - output_reg[i][it + element]; - } - } - } + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } } + } } + } - acc_t sum[WARP_BATCH]; + acc_t sum[WARP_BATCH]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; } - warp_reduce(sum); + } + warp_reduce(sum); - // store result + // store result #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count * stride - + it * WARP_SIZE, out); - } + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); } + copy_vector( + gradInput + i * element_count * stride + it * WARP_SIZE, out); + } } + } } - -template -void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches, - cudaStream_t stream) { - NVTE_CHECK(softmax_elements >= 0 && 
softmax_elements <= 16384, "Unsupported shape."); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr - // value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two - : THREADS_PER_WARP; - - // This value must match the WARP_BATCH constexpr - // value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - NVTE_CHECK(attn_batches % batches_per_block == 0, "Unsupported shape."); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 12: // 4096 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 13: // 8192 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 14: // 16384 - 
scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, - src, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - default: - break; - } +template +void dispatch_scaled_upper_triang_masked_softmax_forward(output_t *dst, const input_t *src, + const input_t scale, int softmax_elements, + int softmax_elements_stride, + int attn_batches, cudaStream_t stream) { + NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 16384, "Unsupported shape."); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr + // value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_BATCH constexpr + // value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + NVTE_CHECK(attn_batches % batches_per_block == 0, "Unsupported shape."); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 12: // 4096 + 
scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 13: // 8192 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 14: // 16384 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + default: + break; } + } } -template -void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - const input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches, - cudaStream_t stream) { - NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 16384, "Unsupported shape."); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr - // value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two - : THREADS_PER_WARP; - - // This value must match the WARP_BATCH constexpr - // value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - NVTE_CHECK(attn_batches % batches_per_block == 0, "Unsupported shape."); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - 
scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 12: // 4096 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 13: // 8192 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - case 14: // 16384 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, - grad, output, - scale, - batch_count, - softmax_elements_stride, - softmax_elements); - break; - default: - break; - } +template +void dispatch_scaled_upper_triang_masked_softmax_backward(output_t *grad_input, const input_t *grad, + const input_t *output, const acc_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches, cudaStream_t stream) { + NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 16384, "Unsupported shape."); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr + // value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_BATCH constexpr + // value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + NVTE_CHECK(attn_batches % batches_per_block == 0, "Unsupported shape."); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 12: // 4096 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 13: // 8192 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 14: // 16384 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + default: + break; } + } } - -void scaled_upper_triang_masked_softmax_forward( - const Tensor input, - Tensor *softmax_results, - float scale_factor, - cudaStream_t stream) { - - const int attn_batches = input.data.shape[0]; - const int seq_len = input.data.shape[1]; - - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.data.dtype, softmax_type, - 
    dispatch_scaled_upper_triang_masked_softmax_forward<softmax_type, softmax_type, float>(
-      reinterpret_cast<softmax_type *>(softmax_results->data.dptr),
-      reinterpret_cast<const softmax_type *>(input.data.dptr),
-      scale_factor,
-      seq_len,
-      seq_len,
-      attn_batches,
-      stream););
+void scaled_upper_triang_masked_softmax_forward(const Tensor input, Tensor *softmax_results,
+                                                float scale_factor, cudaStream_t stream) {
+  const int attn_batches = input.data.shape[0];
+  const int seq_len = input.data.shape[1];
+
+  TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
+      input.data.dtype, softmax_type,
+      dispatch_scaled_upper_triang_masked_softmax_forward<softmax_type, softmax_type, float>(
+          reinterpret_cast<softmax_type *>(softmax_results->data.dptr),
+          reinterpret_cast<const softmax_type *>(input.data.dptr), scale_factor, seq_len, seq_len,
+          attn_batches, stream););
 }
-
-void scaled_upper_triang_masked_softmax_backward(
-    Tensor output_grads,
-    const Tensor incoming_grads,
-    const Tensor softmax_results,
-    float scale_factor,
-    cudaStream_t stream) {
-
-  const int attn_batches = output_grads.data.shape[0];
-  const int seq_len = output_grads.data.shape[1];
-
-  // Softmax Grad
-  TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.data.dtype, softmax_type,
-    dispatch_scaled_upper_triang_masked_softmax_backward<softmax_type, softmax_type, float>(
-      reinterpret_cast<softmax_type *>(output_grads.data.dptr),
-      reinterpret_cast<const softmax_type *>(incoming_grads.data.dptr),
-      reinterpret_cast<const softmax_type *>(softmax_results.data.dptr),
-      scale_factor,
-      seq_len,
-      seq_len,
-      attn_batches,
-      stream););
+void scaled_upper_triang_masked_softmax_backward(Tensor output_grads, const Tensor incoming_grads,
+                                                 const Tensor softmax_results, float scale_factor,
+                                                 cudaStream_t stream) {
+  const int attn_batches = output_grads.data.shape[0];
+  const int seq_len = output_grads.data.shape[1];
+
+  // Softmax Grad
+  TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
+      output_grads.data.dtype, softmax_type,
+      dispatch_scaled_upper_triang_masked_softmax_backward<softmax_type, softmax_type, float>(
+          reinterpret_cast<softmax_type *>(output_grads.data.dptr),
+          reinterpret_cast<const softmax_type *>(incoming_grads.data.dptr),
+          reinterpret_cast<const softmax_type *>(softmax_results.data.dptr), scale_factor, seq_len,
+          seq_len, attn_batches, stream););
 }
 }  // end namespace transformer_engine
-
-void nvte_scaled_upper_triang_masked_softmax_forward(
-    const NVTETensor input,
-    NVTETensor softmax_results,
-    float scale_factor,
-    cudaStream_t stream
-) {
-  using namespace transformer_engine;
-  scaled_upper_triang_masked_softmax_forward(
-    *reinterpret_cast<const Tensor *>(input),
-    reinterpret_cast<Tensor *>(softmax_results),
-    scale_factor,
-    stream);
+void nvte_scaled_upper_triang_masked_softmax_forward(const NVTETensor input,
+                                                     NVTETensor softmax_results, float scale_factor,
+                                                     cudaStream_t stream) {
+  using namespace transformer_engine;
+  scaled_upper_triang_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
+                                             reinterpret_cast<Tensor *>(softmax_results),
+                                             scale_factor, stream);
 }
-
-void nvte_scaled_upper_triang_masked_softmax_backward(
-    const NVTETensor incoming_grads,
-    const NVTETensor softmax_results,
-    NVTETensor output_grads,
-    float scale_factor,
-    cudaStream_t stream
-) {
-  using namespace transformer_engine;
-  scaled_upper_triang_masked_softmax_backward(
-    *reinterpret_cast<Tensor *>(output_grads),
-    *reinterpret_cast<const Tensor *>(incoming_grads),
-    *reinterpret_cast<const Tensor *>(softmax_results),
-    scale_factor,
-    stream);
+void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_grads,
+                                                      const NVTETensor softmax_results,
+                                                      NVTETensor output_grads, float scale_factor,
+                                                      cudaStream_t stream) {
+  using namespace transformer_engine;
+  scaled_upper_triang_masked_softmax_backward(
+      *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
+      *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
 }
diff --git
a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index a4c65661dcd1669c7ea66f0a70fce222980456c8..8a2df1b944da4b6db546069f231d299b57c02b28 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -4,14 +4,14 @@ * See LICENSE for license information. ************************************************************************/ -#include - #include #include #include +#include +#include + #include -#include #include "../common.h" #include "../util/logging.h" @@ -38,7 +38,7 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { uint32_t _getAlignment(uintptr_t address) { // alignment are in bytes uint32_t alignment = 256; - for (; ; alignment /= 2) { + for (;; alignment /= 2) { if (address % alignment == 0) { return alignment; } @@ -49,27 +49,12 @@ uint32_t _getAlignment(uintptr_t address) { namespace transformer_engine { -void cublas_gemm(const Tensor *inputA, - const Tensor *inputB, - Tensor *outputD, - const Tensor *inputBias, - Tensor *outputPreGelu, - int m, int n, int k, - int lda, int ldb, int ldd, - cublasOperation_t transa, - cublasOperation_t transb, - bool grad, - void* workspace, - size_t workspaceSize, - bool accumulate, - bool use_split_accumulator, - int math_sm_count, - int m_split, - int n_split, - bool gemm_producer, - const Tensor *inputCounter, - cudaStream_t stream -) { +void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, + const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda, + int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad, + void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, + int math_sm_count, int m_split, int n_split, bool gemm_producer, + const Tensor *inputCounter, cudaStream_t stream) { void *A = inputA->data.dptr; void *A_scale_inverse = inputA->scale_inv.dptr; void *B = inputB->data.dptr; @@ -86,8 +71,7 @@ void cublas_gemm(const Tensor *inputA, counter = inputCounter->data.dptr; } const bool gelu = pre_gelu_out != nullptr; - const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || - is_fp8_dtype(inputB->data.dtype); + const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype); const cudaDataType_t A_type = get_cuda_dtype(inputA->data.dtype); const cudaDataType_t B_type = get_cuda_dtype(inputB->data.dtype); const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype); @@ -103,11 +87,10 @@ void cublas_gemm(const Tensor *inputA, // fp8 + gelu fusion + fp8 aux is unavailable right now. 
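As a quick illustration, _getAlignment simply reports the largest power of two, capped at 256 bytes, that divides the raw pointer value; these alignments are later fed into the CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_*_BYTES preferences. The example addresses below are made up:

// Illustrative values only (addresses are hypothetical):
//   _getAlignment(0x7f0000000000) == 256  // low bits all zero -> full 256-byte alignment
//   _getAlignment(0x7f0000000140) == 64   // 0x140 = 320 = 64 * 5
//   _getAlignment(0x7f0000000002) == 2    // only 2-byte aligned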
if (use_fp8 && gelu) { NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype), - "fp8 Aux output for gemm + gelu fusion not supported!"); + "fp8 Aux output for gemm + gelu fusion not supported!"); } if (is_fp8_dtype(outputD->data.dtype)) { - NVTE_CHECK(!accumulate, - "Accumulation mode not supported with FP8 GEMM output!"); + NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!"); } float one = 1.0; @@ -117,14 +100,14 @@ void cublas_gemm(const Tensor *inputA, cublasLtHandle_t handle; NVTE_CHECK_CUBLAS(cublasLtCreate(&handle)); - cublasLtMatmulDesc_t operationDesc = nullptr; - cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr; + cublasLtMatmulDesc_t operationDesc = nullptr; + cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr; cublasLtMatmulPreference_t preference = nullptr; - int returnedResults = 0; + int returnedResults = 0; cublasLtMatmulHeuristicResult_t heuristicResult = {}; cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - int64_t ld_gelumat = (int64_t) ldd; + int64_t ld_gelumat = (int64_t)ldd; // Use TF32 only for pure FP32 GEMM. cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F; @@ -133,14 +116,10 @@ void cublas_gemm(const Tensor *inputA, } // Create matrix descriptors. Not setting any extra attributes. - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, - transa == CUBLAS_OP_N ? m : k, - transa == CUBLAS_OP_N ? k : m, - lda)); - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, - transb == CUBLAS_OP_N ? k : n, - transb == CUBLAS_OP_N ? n : k, - ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, transa == CUBLAS_OP_N ? m : k, + transa == CUBLAS_OP_N ? k : m, lda)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, transb == CUBLAS_OP_N ? k : n, + transb == CUBLAS_OP_N ? n : k, ldb)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F)); @@ -150,50 +129,40 @@ void cublas_gemm(const Tensor *inputA, &transb, sizeof(transb))); // Set math SM count if (math_sm_count != 0) { - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, - &math_sm_count, sizeof(math_sm_count))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, + &math_sm_count, sizeof(math_sm_count))); } - // set fp8 attributes -- input and output types should already be set to fp8 as appropriate // Note: gelu fusion isn't available right now, and we don't need // amax(D) either (next op is high precision). if (use_fp8) { // Split accumulator. const int8_t fastAccuMode = (use_split_accumulator) ? 
0 : 1; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_FAST_ACCUM, - &fastAccuMode, - sizeof(fastAccuMode))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &fastAccuMode, sizeof(fastAccuMode))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, - &A_scale_inverse, - sizeof(A_scale_inverse))); + &A_scale_inverse, sizeof(A_scale_inverse))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, - &B_scale_inverse, - sizeof(B_scale_inverse))); + &B_scale_inverse, sizeof(B_scale_inverse))); if (is_fp8_dtype(outputD->data.dtype)) { // Accumulation mode not supported for FP8 output C = nullptr; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, - &D_scale, - sizeof(D_scale))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, - &D_amax, - sizeof(D_amax))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); // For FP8 output, cuBLAS requires C_type to be same as bias_type NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, bias_type, m, n, ldd)); } else { NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd)); } if (bias) { - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, - &bias_type, sizeof(bias_type))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type))); } } else { NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd)); @@ -205,19 +174,16 @@ void cublas_gemm(const Tensor *inputA, } else { epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS; } - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &bias_ptr, sizeof(bias_ptr))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, - &pre_gelu_out, sizeof(pre_gelu_out))); + operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, - &ld_gelumat, sizeof(ld_gelumat))); + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + &pre_gelu_out, sizeof(pre_gelu_out))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat))); const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, - &aux_type, sizeof(aux_type))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type))); } else if (bias) { if (grad) { // grad output is always input B @@ -225,72 +191,64 @@ void cublas_gemm(const Tensor *inputA, } else { epilogue = CUBLASLT_EPILOGUE_BIAS; } - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &bias_ptr, sizeof(bias_ptr))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, 
CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr))); } else if (gelu) { if (grad) { epilogue = CUBLASLT_EPILOGUE_DGELU; } else { epilogue = CUBLASLT_EPILOGUE_GELU_AUX; } - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, - &pre_gelu_out, sizeof(pre_gelu_out))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, - &ld_gelumat, sizeof(ld_gelumat))); + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + &pre_gelu_out, sizeof(pre_gelu_out))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat))); } - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_EPILOGUE, + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); #if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 if (counter != nullptr) { - if (m_split == 0) m_split=1; - if (n_split == 0) n_split=1; + if (m_split == 0) m_split = 1; + if (n_split == 0) n_split = 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, - &m_split, sizeof(m_split))); + operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, &m_split, + sizeof(m_split))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, - &n_split, sizeof(n_split))); + operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, &n_split, + sizeof(n_split))); if (gemm_producer) { NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, - &counter, sizeof(counter))); + operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, &counter, + sizeof(counter))); } else { NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, - &counter, sizeof(counter))); + operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter, + sizeof(counter))); } } #endif NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &workspaceSize, sizeof(workspaceSize))); + preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize))); const auto A_alignment = _getAlignment(reinterpret_cast(A)); const auto B_alignment = _getAlignment(reinterpret_cast(B)); const auto C_alignment = _getAlignment(reinterpret_cast(C)); const auto D_alignment = _getAlignment(reinterpret_cast(D)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, - &A_alignment, sizeof(A_alignment))); + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, - &B_alignment, sizeof(B_alignment))); + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &B_alignment, sizeof(B_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, - &C_alignment, sizeof(C_alignment))); + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - 
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, - &D_alignment, sizeof(D_alignment))); + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment))); - const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, - Ddesc, preference, 1, &heuristicResult, - &returnedResults); + const auto status = + cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, + 1, &heuristicResult, &returnedResults); NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, "Unable to find suitable cuBLAS GEMM algorithm"); NVTE_CHECK_CUBLAS(status); @@ -298,22 +256,16 @@ void cublas_gemm(const Tensor *inputA, if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C - NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, - operationDesc, - static_cast(&one), /* alpha */ - A, /* A */ - Adesc, - B, /* B */ - Bdesc, - static_cast(&beta), /* beta */ - C, /* C */ - Cdesc, - D, /* D */ - Ddesc, - &heuristicResult.algo, /* algo */ - workspace, /* workspace */ - workspaceSize, - stream)); /* stream */ + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, + static_cast(&one), /* alpha */ + A, /* A */ + Adesc, B, /* B */ + Bdesc, static_cast(&beta), /* beta */ + C, /* C */ + Cdesc, D, /* D */ + Ddesc, &heuristicResult.algo, /* algo */ + workspace, /* workspace */ + workspaceSize, stream)); /* stream */ NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc)); @@ -325,27 +277,18 @@ void cublas_gemm(const Tensor *inputA, } // namespace transformer_engine -void nvte_cublas_gemm(const NVTETensor A, - const NVTETensor B, - NVTETensor D, - const NVTETensor bias, - NVTETensor pre_gelu_out, - bool transa, - bool transb, - bool grad, - NVTETensor workspace, - bool accumulate, - bool use_split_accumulator, - int math_sm_count, - cudaStream_t stream) { +void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, + NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, + NVTETensor workspace, bool accumulate, bool use_split_accumulator, + int math_sm_count, cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_gemm); using namespace transformer_engine; - const Tensor *inputA = reinterpret_cast(A); - const Tensor *inputB = reinterpret_cast(B); - Tensor *outputD = reinterpret_cast(D); - const Tensor *biasTensor = reinterpret_cast(bias); - Tensor *outputGelu = reinterpret_cast(pre_gelu_out); - Tensor *wspace = reinterpret_cast(workspace); + const Tensor *inputA = reinterpret_cast(A); + const Tensor *inputB = reinterpret_cast(B); + Tensor *outputD = reinterpret_cast(D); + const Tensor *biasTensor = reinterpret_cast(bias); + Tensor *outputGelu = reinterpret_cast(pre_gelu_out); + Tensor *wspace = reinterpret_cast(workspace); const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; @@ -367,42 +310,17 @@ void nvte_cublas_gemm(const NVTETensor A, NVTE_ERROR("TT layout not allowed."); } - cublas_gemm(inputA, - inputB, - outputD, - biasTensor, - outputGelu, - m, n, k, - lda, ldb, ldd, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, - (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, - grad, wspace->data.dptr, - wspace->data.shape[0], - accumulate, use_split_accumulator, - math_sm_count, - 0, - 0, - false, - nullptr, - stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, + (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, + wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, + math_sm_count, 0, 0, false, nullptr, stream); } -void nvte_cublas_atomic_gemm(const NVTETensor A, - const NVTETensor B, - NVTETensor D, - const NVTETensor bias, - NVTETensor pre_gelu_out, - bool transa, - bool transb, - bool grad, - NVTETensor workspace, - bool accumulate, - bool use_split_accumulator, - int math_sm_count, - int m_split, - int n_split, - bool gemm_producer, - const NVTETensor counter, +void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, + const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, + bool transb, bool grad, NVTETensor workspace, bool accumulate, + bool use_split_accumulator, int math_sm_count, int m_split, + int n_split, bool gemm_producer, const NVTETensor counter, cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_atomic_gemm); @@ -412,13 +330,13 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm."); using namespace transformer_engine; - const Tensor *inputA = reinterpret_cast(A); - const Tensor *inputB = reinterpret_cast(B); - Tensor *outputD = reinterpret_cast(D); - const Tensor *biasTensor = reinterpret_cast(bias); - Tensor *outputGelu = reinterpret_cast(pre_gelu_out); - const Tensor *inputCounter = reinterpret_cast(counter); - Tensor *wspace = reinterpret_cast(workspace); + const Tensor *inputA = reinterpret_cast(A); + const Tensor *inputB = reinterpret_cast(B); + Tensor *outputD = reinterpret_cast(D); + const Tensor *biasTensor = reinterpret_cast(bias); + Tensor *outputGelu = reinterpret_cast(pre_gelu_out); + const Tensor *inputCounter = reinterpret_cast(counter); + Tensor *wspace = reinterpret_cast(workspace); const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; @@ -440,22 +358,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, NVTE_ERROR("TT layout not allowed."); } - cublas_gemm(inputA, - inputB, - outputD, - biasTensor, - outputGelu, - m, n, k, - lda, ldb, ldd, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, - (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, - grad, wspace->data.dptr, - wspace->data.shape[0], - accumulate, use_split_accumulator, - math_sm_count, - m_split, - n_split, - gemm_producer, - inputCounter, - stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, + (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, grad, + wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, + math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); } diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 0de1c03f12aee2386e1857c39d6ba7916bfbd2fc..656c647fd4761de51c4c640e227e5a8bbe925be0 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -39,25 +39,15 @@ enum class NVTE_Activation_Type { SREGLU, }; -void nvte_gelu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_silu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_relu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_qgelu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_srelu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Compute activation gradient. * @@ -66,31 +56,20 @@ void nvte_srelu(const NVTETensor input, * \param[in,out] output Output tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_dgelu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_dsilu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_drelu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_dqgelu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_dsrelu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, - cudaStream_t stream); - +void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream); /*! \brief Compute gated activation of the input. * @@ -99,25 +78,15 @@ void nvte_dsrelu(const NVTETensor grad, * It computes Act(input[N, :H]) x input[N, H:] * \param[in] stream CUDA stream used for the operation. 
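Concretely, the gated variants consume an input whose last dimension is 2*H and produce an output of width H, gating the activated first half with the second half. A minimal host-side reference sketch of that indexing (illustrative only; act stands for the chosen activation):

// out[n][h] = act(in[n][h]) * in[n][H + h], with in laid out as [N, 2*H] row-major.
void gated_act_reference(const float *in, float *out, int N, int H, float (*act)(float)) {
  for (int n = 0; n < N; ++n) {
    for (int h = 0; h < H; ++h) {
      out[n * H + h] = act(in[n * 2 * H + h]) * in[n * 2 * H + H + h];
    }
  }
}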
*/ -void nvte_geglu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_swiglu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_reglu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_qgeglu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_sreglu(const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Compute gated activation gradient. * \param[in] grad Incoming gradient of shape [N, H]. @@ -125,30 +94,20 @@ void nvte_sreglu(const NVTETensor input, * \param[in,out] output Outgoing gradient of shape [N, H * 2]. * \param[in] stream CUDA stream used for the operation. */ -void nvte_dgeglu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_dswiglu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_dreglu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, +void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); -void nvte_dqgeglu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream); -void nvte_dsreglu(const NVTETensor grad, - const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index b6d145c4bf69fb66d7cfd4554d3a2d0f5c183258..32f16922b9db8784d43f849ebeac52a3e8d4ca16 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -23,9 +23,7 @@ extern "C" { * \param[in,out] output Output FP8 tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_fp8_quantize(const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Cast tensor from FP8. * @@ -33,9 +31,7 @@ void nvte_fp8_quantize(const NVTETensor input, * \param[out] output Output tensor. * \param[in] stream CUDA stream used for the operation. 
*/ -void nvte_fp8_dequantize(const NVTETensor input, - NVTETensor output, - cudaStream_t stream); +void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h index f9097679a6f163d9df83b4b68974de51c754bca2..9043162bcbd531a4eade00811b4736fc7989fefb 100644 --- a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h +++ b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h @@ -17,15 +17,11 @@ extern "C" { #endif -void nvte_transpose_with_noop(const NVTETensor input, - const NVTETensor noop, - NVTETensor output, +void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream); -void nvte_cast_transpose_with_noop(const NVTETensor input, - const NVTETensor noop, - NVTETensor cast_output, - NVTETensor transposed_output, +void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, + NVTETensor cast_output, NVTETensor transposed_output, cudaStream_t stream); #ifdef __cplusplus diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 38adc66accaa70155b87eb4726e2b1d7d9eb81e5..dac3e0620e2aacd7be1f1bcfb637d75950ebf26a 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -26,91 +26,91 @@ extern "C" { * different lengths in a batch. */ enum NVTE_QKV_Layout { - NVTE_SB3HD = 0, /*!< SB3HD layout */ - NVTE_SBH3D = 1, /*!< SBH3D layout */ - NVTE_SBHD_SB2HD = 2, /*!< SBHD_SB2HD layout */ - NVTE_SBHD_SBH2D = 3, /*!< SBHD_SBH2D layout */ - NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */ - NVTE_BS3HD = 5, /*!< BS3HD layout */ - NVTE_BSH3D = 6, /*!< BSH3D layout */ - NVTE_BSHD_BS2HD = 7, /*!< BSHD_BS2HD layout */ - NVTE_BSHD_BSH2D = 8, /*!< BSHD_BSH2D layout */ - NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */ - NVTE_T3HD = 10, /*!< T3HD layout */ - NVTE_TH3D = 11, /*!< TH3D layout */ - NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ - NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ - NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ + NVTE_SB3HD = 0, /*!< SB3HD layout */ + NVTE_SBH3D = 1, /*!< SBH3D layout */ + NVTE_SBHD_SB2HD = 2, /*!< SBHD_SB2HD layout */ + NVTE_SBHD_SBH2D = 3, /*!< SBHD_SBH2D layout */ + NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */ + NVTE_BS3HD = 5, /*!< BS3HD layout */ + NVTE_BSH3D = 6, /*!< BSH3D layout */ + NVTE_BSHD_BS2HD = 7, /*!< BSHD_BS2HD layout */ + NVTE_BSHD_BSH2D = 8, /*!< BSHD_BSH2D layout */ + NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */ + NVTE_T3HD = 10, /*!< T3HD layout */ + NVTE_TH3D = 11, /*!< TH3D layout */ + NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ + NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ + NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ }; /*! \enum NVTE_QKV_Layout_Group * \brief QKV layout groups */ enum NVTE_QKV_Layout_Group { - /*! 3HD QKV layouts, i.e. BS3HD, SB3HD, T3HD */ - NVTE_3HD = 0, - /*! H3D QKV layouts, i.e. BSH3D, SBH3D, TH3D */ - NVTE_H3D = 1, - /*! HD_2HD QKV layouts, i.e. BSHD_BS2HD, SBHD_SB2HD, THD_T2HD */ - NVTE_HD_2HD = 2, - /*! HD_H2D QKV layouts, i.e. BSHD_BSH2D, SBHD_SBH2D, THD_TH2D */ - NVTE_HD_H2D = 3, - /*! HD_HD_HD QKV layouts, i.e. 
BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD */ - NVTE_HD_HD_HD = 4, + /*! 3HD QKV layouts, i.e. BS3HD, SB3HD, T3HD */ + NVTE_3HD = 0, + /*! H3D QKV layouts, i.e. BSH3D, SBH3D, TH3D */ + NVTE_H3D = 1, + /*! HD_2HD QKV layouts, i.e. BSHD_BS2HD, SBHD_SB2HD, THD_T2HD */ + NVTE_HD_2HD = 2, + /*! HD_H2D QKV layouts, i.e. BSHD_BSH2D, SBHD_SBH2D, THD_TH2D */ + NVTE_HD_H2D = 3, + /*! HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD */ + NVTE_HD_HD_HD = 4, }; /*! \enum NVTE_QKV_Format * \brief QKV formats */ enum NVTE_QKV_Format { - /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD */ - NVTE_SBHD = 0, - /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD */ - NVTE_BSHD = 1, - /*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */ - NVTE_THD = 2, + /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD */ + NVTE_SBHD = 0, + /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD */ + NVTE_BSHD = 1, + /*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */ + NVTE_THD = 2, }; /*! \enum NVTE_Bias_Type * \brief Bias types */ enum NVTE_Bias_Type { - /*! No bias */ - NVTE_NO_BIAS = 0, - /*! Bias before scale */ - NVTE_PRE_SCALE_BIAS = 1, - /*! Bias after scale */ - NVTE_POST_SCALE_BIAS = 2, - /*! ALiBi */ - NVTE_ALIBI = 3, + /*! No bias */ + NVTE_NO_BIAS = 0, + /*! Bias before scale */ + NVTE_PRE_SCALE_BIAS = 1, + /*! Bias after scale */ + NVTE_POST_SCALE_BIAS = 2, + /*! ALiBi */ + NVTE_ALIBI = 3, }; /*! \enum NVTE_Mask_Type * \brief Attention mask types */ enum NVTE_Mask_Type { - /*! No masking */ - NVTE_NO_MASK = 0, - /*! Padding attention mask */ - NVTE_PADDING_MASK = 1, - /*! Causal attention mask */ - NVTE_CAUSAL_MASK = 2, - /*! Padding and causal attention mask */ - NVTE_PADDING_CAUSAL_MASK = 3, + /*! No masking */ + NVTE_NO_MASK = 0, + /*! Padding attention mask */ + NVTE_PADDING_MASK = 1, + /*! Causal attention mask */ + NVTE_CAUSAL_MASK = 2, + /*! Padding and causal attention mask */ + NVTE_PADDING_CAUSAL_MASK = 3, }; /*! \enum NVTE_Fused_Attn_Backend * \brief Fused attention backends */ enum NVTE_Fused_Attn_Backend { - /*! No supported backend */ - NVTE_No_Backend = -1, - /*! cuDNN-based FP16/BF16 fused attention for <= 512 sequence length */ - NVTE_F16_max512_seqlen = 0, - /*! cuDNN-based FP16/BF16 fused attention for any sequence length */ - NVTE_F16_arbitrary_seqlen = 1, - /*! cuDNN-based FP8 fused attention for <= 512 sequence length */ - NVTE_FP8 = 2, + /*! No supported backend */ + NVTE_No_Backend = -1, + /*! cuDNN-based FP16/BF16 fused attention for <= 512 sequence length */ + NVTE_F16_max512_seqlen = 0, + /*! cuDNN-based FP16/BF16 fused attention for any sequence length */ + NVTE_F16_arbitrary_seqlen = 1, + /*! cuDNN-based FP8 fused attention for <= 512 sequence length */ + NVTE_FP8 = 2, }; /*! \brief Get QKV layout group for a given QKV layout. @@ -144,15 +144,9 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout); * \param[in] head_dim The head dimension of Q, K, V. 
*/ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - NVTEDType q_dtype, - NVTEDType kv_dtype, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - float dropout, - size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim); + NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim); /*! \brief Compute dot product attention with packed QKV input. * @@ -211,24 +205,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd_qkvpacked( - const NVTETensor QKV, - const NVTETensor Bias, - NVTETensor S, - NVTETensor O, - NVTETensorPack* Aux_CTX_Tensors, - const NVTETensor cu_seqlens, - const NVTETensor seq_offsets_q, - const NVTETensor seq_offsets_k, - const NVTETensor seq_offsets_v, - const NVTETensor seq_offsets_o, - const NVTETensor rng_state, - size_t max_seqlen, - bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - NVTETensor workspace, - cudaStream_t stream); +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, + NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, + const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, float attn_scale, + float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. * @@ -283,25 +268,12 @@ void nvte_fused_attn_fwd_qkvpacked( * \param[in] stream CUDA stream used for this operation. */ void nvte_fused_attn_bwd_qkvpacked( - const NVTETensor QKV, - const NVTETensor O, - const NVTETensor dO, - const NVTETensor S, - NVTETensor dP, - const NVTETensorPack* Aux_CTX_Tensors, - NVTETensor dQKV, - NVTETensor dBias, - const NVTETensor cu_seqlens, - const NVTETensor seq_offsets_q, - const NVTETensor seq_offsets_k, - const NVTETensor seq_offsets_v, - const NVTETensor seq_offsets_o, - size_t max_seqlen, - float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - NVTETensor workspace, - cudaStream_t stream); + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, size_t max_seqlen, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. * @@ -363,26 +335,16 @@ void nvte_fused_attn_bwd_qkvpacked( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
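For orientation, a typical call to the backend query above might look like the following sketch; the kNVTEBFloat16 enumerator is assumed from the library's dtype enum (it is not shown in this patch), and the numeric values are made up:

// Ask which fused-attention backend would serve a BF16, BSHD, causal, no-bias case.
NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
    kNVTEBFloat16, kNVTEBFloat16, NVTE_BSHD_BSHD_BSHD, NVTE_NO_BIAS, NVTE_CAUSAL_MASK,
    /*dropout=*/0.0f, /*num_attn_heads=*/16, /*num_gqa_groups=*/16,
    /*max_seqlen_q=*/512, /*max_seqlen_kv=*/512, /*head_dim=*/64);
if (backend == NVTE_No_Backend) {
  // fall back to an unfused attention implementation
}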
*/ -void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, - const NVTETensor KV, - const NVTETensor Bias, - NVTETensor S, - NVTETensor O, - NVTETensorPack* Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, - const NVTETensor seq_offsets_q, - const NVTETensor seq_offsets_k, - const NVTETensor seq_offsets_v, - const NVTETensor seq_offsets_o, - const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, - bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - NVTETensor workspace, - cudaStream_t stream); +void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, + NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, + const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, float attn_scale, + float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. * @@ -441,28 +403,13 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] stream CUDA stream used for this operation. */ void nvte_fused_attn_bwd_kvpacked( - const NVTETensor Q, - const NVTETensor KV, - const NVTETensor O, - const NVTETensor dO, - const NVTETensor S, - NVTETensor dP, - const NVTETensorPack* Aux_CTX_Tensors, - NVTETensor dQ, - NVTETensor dKV, - NVTETensor dBias, - const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, - const NVTETensor seq_offsets_q, - const NVTETensor seq_offsets_k, - const NVTETensor seq_offsets_v, - const NVTETensor seq_offsets_o, - size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - NVTETensor workspace, - cudaStream_t stream); + const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, + const NVTETensor S, NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, + NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, + float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. * @@ -538,27 +485,16 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_fused_attn_fwd( - const NVTETensor Q, - const NVTETensor K, - const NVTETensor V, - const NVTETensor Bias, - NVTETensor S, - NVTETensor O, - NVTETensorPack* Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, - const NVTETensor seq_offsets_q, - const NVTETensor seq_offsets_k, - const NVTETensor seq_offsets_v, - const NVTETensor seq_offsets_o, - const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, - bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - NVTETensor workspace, - cudaStream_t stream); +void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor Bias, NVTETensor S, NVTETensor O, + NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -631,31 +567,16 @@ void nvte_fused_attn_fwd( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_bwd( - const NVTETensor Q, - const NVTETensor K, - const NVTETensor V, - const NVTETensor O, - const NVTETensor dO, - const NVTETensor S, - NVTETensor dP, - const NVTETensorPack* Aux_CTX_Tensors, - NVTETensor dQ, - NVTETensor dK, - NVTETensor dV, - NVTETensor dBias, - const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, - const NVTETensor seq_offsets_q, - const NVTETensor seq_offsets_k, - const NVTETensor seq_offsets_v, - const NVTETensor seq_offsets_o, - size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - NVTETensor workspace, - cudaStream_t stream); +void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, + NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, cudaStream_t stream); #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index cb712aecff5143172a2297a76d87e9ea4a32140a..b92de88eca5868ef1b9d83f3a56db7a0deabd3b2 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -33,14 +33,11 @@ extern "C" { * \param[in] o_stride_d Stride of the d dimension of output. * \param[in] stream CUDA stream used for the operation. 
*/ -void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, - NVTETensor output, const int s, const int b, - const int h, const int d, const int d2, - const int stride_s, const int stride_b, - const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, - cudaStream_t stream); +void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, cudaStream_t stream); /*! \brief Compute the backward of the fused rope. * @@ -62,14 +59,12 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, * \param[in] o_stride_d Stride of the d dimension of input_grads. * \param[in] stream CUDA stream used for the operation. */ -void nvte_fused_rope_backward(const NVTETensor output_grads, - const NVTETensor freqs, NVTETensor input_grads, - const int s, const int b, const int h, - const int d, const int d2, const int stride_s, - const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, - const int o_stride_d, cudaStream_t stream); +void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, + NVTETensor input_grads, const int s, const int b, const int h, + const int d, const int d2, const int stride_s, const int stride_b, + const int stride_h, const int stride_d, const int o_stride_s, + const int o_stride_b, const int o_stride_h, const int o_stride_d, + cudaStream_t stream); /*! \brief Apply rotary positional embedding to the input tensor in thd format. * @@ -90,14 +85,12 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, * \param[in] o_stride_d Stride of the d dimension of output. * \param[in] stream CUDA stream used for the operation. */ -void nvte_fused_rope_thd_forward(const NVTETensor input, - const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, - const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream); +void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor output, const int max_s, + const int b, const int h, const int d, const int d2, + const int stride_t, const int stride_h, const int stride_d, + const int o_stride_t, const int o_stride_h, const int o_stride_d, + cudaStream_t stream); /*! \brief Compute the backward of the fused rope in thd format. * @@ -118,12 +111,12 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, * \param[in] o_stride_d Stride of the d dimension of input_grads. * \param[in] stream CUDA stream used for the operation. 
*/ -void nvte_fused_rope_thd_backward( - const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int max_s, - const int b, const int h, const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d, cudaStream_t stream); +void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor input_grads, const int max_s, + const int b, const int h, const int d, const int d2, + const int stride_t, const int stride_h, const int stride_d, + const int o_stride_t, const int o_stride_h, const int o_stride_d, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 7107aca5fe019763f1a7b3c0d0a73f6f3393fd1f..b9186707dd251b24ca1620f948733b6ff1aa17a2 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -39,20 +39,10 @@ extern "C" { * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) * \param[in] stream CUDA stream used for the operation. */ -void nvte_cublas_gemm(const NVTETensor A, - const NVTETensor B, - NVTETensor D, - const NVTETensor bias, - NVTETensor pre_gelu_out, - bool transa, - bool transb, - bool grad, - NVTETensor workspace, - bool accumulate, - bool use_split_accumulator, - int math_sm_count, - cudaStream_t stream -); +void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, + NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, + NVTETensor workspace, bool accumulate, bool use_split_accumulator, + int math_sm_count, cudaStream_t stream); /*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters. * @@ -82,27 +72,14 @@ void nvte_cublas_gemm(const NVTETensor A, * \param[in,out] counter counter[chunk_i]=0 indicates chunk_i has been produced. * \param[in] stream CUDA stream used for the operation. */ -void nvte_cublas_atomic_gemm(const NVTETensor A, - const NVTETensor B, - NVTETensor D, - const NVTETensor bias, - NVTETensor pre_gelu_out, - bool transa, - bool transb, - bool grad, - NVTETensor workspace, - bool accumulate, - bool use_split_accumulator, - int math_sm_count, - int m_split, - int n_split, - bool gemm_producer, - const NVTETensor counter, - cudaStream_t stream -); +void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, + const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, + bool transb, bool grad, NVTETensor workspace, bool accumulate, + bool use_split_accumulator, int math_sm_count, int m_split, + int n_split, bool gemm_producer, const NVTETensor counter, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" #endif - #endif // TRANSFORMER_ENGINE_GEMM_H_ diff --git a/transformer_engine/common/include/transformer_engine/layer_norm.h b/transformer_engine/common/include/transformer_engine/layer_norm.h index 6665b7ba5fb2b96fd7f665f34554f31d9e910cb7..3bb4d47f29b94b90e8345621cacd33dc09747913 100644 --- a/transformer_engine/common/include/transformer_engine/layer_norm.h +++ b/transformer_engine/common/include/transformer_engine/layer_norm.h @@ -42,16 +42,9 @@ extern "C" { * \param[out] workspace Workspace tensor. * \param[out] barrier Barrier tensor. 
*/ -void nvte_layernorm_fwd(const NVTETensor x, - const NVTETensor gamma, - const NVTETensor beta, - const float epsilon, - NVTETensor z, - NVTETensor mu, - NVTETensor rsigma, - cudaStream_t stream, - const int multiprocessorCount, - NVTETensor workspace, +void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta, + const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, + cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier); /*! \brief Compute LayerNorm with zero-centered gamma on the input. @@ -79,19 +72,11 @@ void nvte_layernorm_fwd(const NVTETensor x, * \param[out] workspace Workspace tensor. * \param[out] barrier Barrier tensor. */ -void nvte_layernorm1p_fwd(const NVTETensor x, - const NVTETensor gamma, - const NVTETensor beta, - const float epsilon, - NVTETensor z, - NVTETensor mu, - NVTETensor rsigma, - cudaStream_t stream, - const int multiprocessorCount, - NVTETensor workspace, +void nvte_layernorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta, + const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, + cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier); - /*! \brief Compute backward of LayerNorm. * * This function computes the gradient of function: @@ -121,20 +106,14 @@ void nvte_layernorm1p_fwd(const NVTETensor x, * \param[out] workspace Workspace tensor. * \param[out] barrier Barrier tensor. */ -void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, - NVTETensor dgamma, - NVTETensor dbeta, - NVTETensor dgamma_part, - NVTETensor dbeta_part, - cudaStream_t stream, - const int multiprocessorCount, - NVTETensor workspace, - NVTETensor barrier); +void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size + const NVTETensor x, // BxSxhidden_size + const NVTETensor mu, // BxS, FP32! + const NVTETensor rsigma, // BxS, FP32! + const NVTETensor gamma, // hidden_size + NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part, + NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount, + NVTETensor workspace, NVTETensor barrier); /*! \brief Compute backward of LayerNorm with zero-centered gamma. * @@ -165,20 +144,14 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size * \param[out] workspace Workspace tensor. * \param[out] barrier Barrier tensor. */ -void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, - NVTETensor dgamma, - NVTETensor dbeta, - NVTETensor dgamma_part, - NVTETensor dbeta_part, - cudaStream_t stream, - const int multiprocessorCount, - NVTETensor workspace, - NVTETensor barrier); +void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size + const NVTETensor x, // BxSxhidden_size + const NVTETensor mu, // BxS, FP32! + const NVTETensor rsigma, // BxS, FP32! 
+ const NVTETensor gamma, // hidden_size + NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, + NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream, + const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier); #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 49cc9af9140b0665a780dc3d9bd4d14e53ecbaa6..61b1f231b82feee0c211b7eeda270ea15188f027 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -44,18 +44,11 @@ extern "C" { * \param[in] margin Scaling factor margin. * \param[in] stream CUDA stream. */ -void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history, - const NVTETensor scale, - const NVTETensor scale_inv, - const NVTETensor scale_inv_mask, - NVTETensor updated_amax_history, - NVTETensor updated_scale, - NVTETensor updated_scale_inv, - const char* amax_compute_algo, - NVTEDType fp8_dtype, - float margin, - cudaStream_t stream); - +void nvte_delayed_scaling_recipe_amax_and_scale_update( + const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv, + const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale, + NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, + cudaStream_t stream); /*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction. * @@ -85,15 +78,9 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his * \param[in] stream CUDA stream. */ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( - const NVTETensor amax_reduction_buffer, - std::vector amax_histories, - std::vector scales, - std::vector scale_invs, - const char *amax_compute_algo, - NVTEDType fp8_dtype, - float margin, - cudaStream_t stream); - + const NVTETensor amax_reduction_buffer, std::vector amax_histories, + std::vector scales, std::vector scale_invs, + const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/rmsnorm.h b/transformer_engine/common/include/transformer_engine/rmsnorm.h index 8f5148333131c140ac957293303952996c33ffc3..dc995e3c242f43e011ce2090287be85fa43028b8 100644 --- a/transformer_engine/common/include/transformer_engine/rmsnorm.h +++ b/transformer_engine/common/include/transformer_engine/rmsnorm.h @@ -43,15 +43,9 @@ extern "C" { * \param[out] workspace Workspace tensor. * \param[out] barrier Barrier tensor. */ -void nvte_rmsnorm_fwd(const NVTETensor x, - const NVTETensor gamma, - const float epsilon, - NVTETensor z, - NVTETensor rsigma, - cudaStream_t stream, - const int multiprocessorCount, - NVTETensor workspace, - NVTETensor barrier); +void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z, + NVTETensor rsigma, cudaStream_t stream, const int multiprocessorCount, + NVTETensor workspace, NVTETensor barrier); /*! \brief Compute RMSNorm with zero-centered gamma on the input. * @@ -79,15 +73,9 @@ void nvte_rmsnorm_fwd(const NVTETensor x, * \param[out] workspace Workspace tensor. * \param[out] barrier Barrier tensor. 
*/ -void nvte_rmsnorm1p_fwd(const NVTETensor x, - const NVTETensor gamma, - const float epsilon, - NVTETensor z, - NVTETensor rsigma, - cudaStream_t stream, - const int multiprocessorCount, - NVTETensor workspace, - NVTETensor barrier); +void nvte_rmsnorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, + NVTETensor z, NVTETensor rsigma, cudaStream_t stream, + const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier); /*! \brief Compute backward of RMSNorm. * @@ -118,18 +106,10 @@ void nvte_rmsnorm1p_fwd(const NVTETensor x, * \param[out] workspace Workspace tensor. * \param[out] barrier Barrier tensor. */ -void nvte_rmsnorm_bwd(const NVTETensor dz, - const NVTETensor x, - const NVTETensor rsigma, - const NVTETensor gamma, - NVTETensor dx, - NVTETensor dgamma, - NVTETensor dgamma_part, - cudaStream_t stream, - const int multiprocessorCount, - NVTETensor workspace, - NVTETensor barrier -); +void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma, + const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma, + NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount, + NVTETensor workspace, NVTETensor barrier); /*! \brief Compute backward of RMSNorm with zero-centered gamma. * @@ -160,18 +140,10 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, * \param[out] workspace Workspace tensor. * \param[out] barrier Barrier tensor. */ -void nvte_rmsnorm1p_bwd(const NVTETensor dz, - const NVTETensor x, - const NVTETensor rsigma, - const NVTETensor gamma, - NVTETensor dx, - NVTETensor dgamma, - NVTETensor dgamma_part, - cudaStream_t stream, - const int multiprocessorCount, - NVTETensor workspace, - NVTETensor barrier -); +void nvte_rmsnorm1p_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma, + const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma, + NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount, + NVTETensor workspace, NVTETensor barrier); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/softmax.h b/transformer_engine/common/include/transformer_engine/softmax.h index 50f0a006ee8bb0924813e49857418d07c94577b0..6a6fc15fa67cd10214a73eda373faf801daf0aeb 100644 --- a/transformer_engine/common/include/transformer_engine/softmax.h +++ b/transformer_engine/common/include/transformer_engine/softmax.h @@ -7,8 +7,9 @@ #ifndef TRANSFORMER_ENGINE_SOFTMAX_H_ #define TRANSFORMER_ENGINE_SOFTMAX_H_ -#include #include +#include + #include "transformer_engine.h" #ifdef __cplusplus @@ -22,13 +23,8 @@ extern "C" { * \param[in] scale_factor Scalar for the input tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_scaled_softmax_forward( - const NVTETensor input, - NVTETensor softmax_results, - float scale_factor, - cudaStream_t stream -); - +void nvte_scaled_softmax_forward(const NVTETensor input, NVTETensor softmax_results, + float scale_factor, cudaStream_t stream); /*! \brief Compute the backward of the scaled softmax activation. * @@ -42,14 +38,8 @@ void nvte_scaled_softmax_forward( * \param[in] scale_factor Scalar for the output tensor. * \param[in] stream CUDA stream used for the operation. 
*/ -void nvte_scaled_softmax_backward( - const NVTETensor incoming_grads, - const NVTETensor softmax_results, - NVTETensor output_grads, - float scale_factor, - cudaStream_t stream -); - +void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results, + NVTETensor output_grads, float scale_factor, cudaStream_t stream); /*! \brief Compute scaled masked softmax activation on the input. * @@ -59,14 +49,9 @@ void nvte_scaled_softmax_backward( * \param[in] scale_factor Scalar for the input tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_scaled_masked_softmax_forward( - const NVTETensor input, - const NVTETensor mask, - NVTETensor softmax_results, - float scale_factor, - cudaStream_t stream -); - +void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor mask, + NVTETensor softmax_results, float scale_factor, + cudaStream_t stream); /*! \brief Compute the backward of the scaled masked softmax activation. * @@ -80,14 +65,9 @@ void nvte_scaled_masked_softmax_forward( * \param[in] scale_factor Scalar for the output tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_scaled_masked_softmax_backward( - const NVTETensor incoming_grads, - const NVTETensor softmax_results, - NVTETensor output_grads, - float scale_factor, - cudaStream_t stream -); - +void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads, + const NVTETensor softmax_results, NVTETensor output_grads, + float scale_factor, cudaStream_t stream); /*! \brief Compute scaled softmax activation using a 2D upper triangular mask on the input. * @@ -96,13 +76,9 @@ void nvte_scaled_masked_softmax_backward( * \param[in] scale_factor Scalar for the input tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_scaled_upper_triang_masked_softmax_forward( - const NVTETensor input, - NVTETensor softmax_results, - float scale_factor, - cudaStream_t stream -); - +void nvte_scaled_upper_triang_masked_softmax_forward(const NVTETensor input, + NVTETensor softmax_results, float scale_factor, + cudaStream_t stream); /*! \brief Compute the backward of the scaled softmax activation using a 2D upper triangular mask. * @@ -116,14 +92,10 @@ void nvte_scaled_upper_triang_masked_softmax_forward( * \param[in] scale_factor Scalar for the output tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_scaled_upper_triang_masked_softmax_backward( - const NVTETensor incoming_grads, - const NVTETensor softmax_results, - NVTETensor output_grads, - float scale_factor, - cudaStream_t stream -); - +void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_grads, + const NVTETensor softmax_results, + NVTETensor output_grads, float scale_factor, + cudaStream_t stream); /*! \brief Compute scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix. * @@ -132,13 +104,9 @@ void nvte_scaled_upper_triang_masked_softmax_backward( * \param[in] scale_factor Scalar for the input tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_scaled_aligned_causal_masked_softmax_forward( - const NVTETensor input, - NVTETensor softmax_results, - float scale_factor, - cudaStream_t stream -); - +void nvte_scaled_aligned_causal_masked_softmax_forward(const NVTETensor input, + NVTETensor softmax_results, + float scale_factor, cudaStream_t stream); /*! 
\brief Compute the backward pass of the scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix. * @@ -152,13 +120,10 @@ void nvte_scaled_aligned_causal_masked_softmax_forward( * \param[in] scale_factor Scalar for the output tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_scaled_aligned_causal_masked_softmax_backward( - const NVTETensor incoming_grads, - const NVTETensor softmax_results, - NVTETensor output_grads, - float scale_factor, - cudaStream_t stream -); +void nvte_scaled_aligned_causal_masked_softmax_backward(const NVTETensor incoming_grads, + const NVTETensor softmax_results, + NVTETensor output_grads, float scale_factor, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 6a76893ce116aacfef96d3d4aeb3798e15e79452..534f6b2181e7acd61763693a48dd19c95da9f39c 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -11,8 +11,8 @@ #ifndef TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_ #define TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_ -#include #include +#include #ifdef __cplusplus extern "C" { @@ -22,15 +22,15 @@ extern "C" { * \brief TE datatype. */ enum NVTEDType { - kNVTEByte = 0, /*!< Byte */ - kNVTEInt32 = 1, /*!< 32-bit integer */ - kNVTEInt64 = 2, /*!< 32-bit integer */ - kNVTEFloat32 = 3, /*!< 32-bit float */ - kNVTEFloat16 = 4, /*!< 16-bit float (E5M10) */ - kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */ - kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */ - kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */ - kNVTENumTypes /*!< Number of supported types */ + kNVTEByte = 0, /*!< Byte */ + kNVTEInt32 = 1, /*!< 32-bit integer */ + kNVTEInt64 = 2, /*!< 32-bit integer */ + kNVTEFloat32 = 3, /*!< 32-bit float */ + kNVTEFloat16 = 4, /*!< 16-bit float (E5M10) */ + kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */ + kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */ + kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */ + kNVTENumTypes /*!< Number of supported types */ }; /*! \struct NVTEShape @@ -49,7 +49,7 @@ struct NVTEShape { * to data of a given shape and type. It does not own the * memory it points to. */ -typedef void* NVTETensor; +typedef void *NVTETensor; /*! \brief Create a new TE tensor. * @@ -66,12 +66,8 @@ typedef void* NVTETensor; * * \return A new TE tensor. */ -NVTETensor nvte_create_tensor(void *dptr, - const NVTEShape shape, - const NVTEDType dtype, - float *amax_dptr, - float *scale_dptr, - float *scale_inv_dptr); +NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, + float *amax_dptr, float *scale_dptr, float *scale_inv_dptr); /*! \brief Destroy a TE tensor. * @@ -144,11 +140,11 @@ struct NVTETensorPack { /*! \brief Create `tensors` in NVTETensorPack. */ -void nvte_tensor_pack_create(NVTETensorPack* pack); +void nvte_tensor_pack_create(NVTETensorPack *pack); /*! \brief Destroy `tensors` in NVTETensorPack. */ -void nvte_tensor_pack_destroy(NVTETensorPack* pack); +void nvte_tensor_pack_destroy(NVTETensorPack *pack); #ifdef __cplusplus } // extern "C" @@ -164,12 +160,12 @@ namespace transformer_engine { * \brief TE datatype. 
*/ enum class DType { - kByte = 0, - kInt32 = 1, - kInt64 = 2, - kFloat32 = 3, - kFloat16 = 4, - kBFloat16 = 5, + kByte = 0, + kInt32 = 1, + kInt64 = 2, + kFloat32 = 3, + kFloat16 = 4, + kBFloat16 = 5, kFloat8E4M3 = 6, kFloat8E5M2 = 7, kNumTypes @@ -193,11 +189,10 @@ class TensorWrapper { * \param[in] scale_dptr Pointer to the scale value. * \param[in] scale_inv_dptr Pointer to the inverse of scale value. */ - TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, - float *amax_dptr = nullptr, float *scale_dptr = nullptr, - float *scale_inv_dptr = nullptr) : - tensor_(nvte_create_tensor(dptr, shape, static_cast(dtype), - amax_dptr, scale_dptr, scale_inv_dptr)) {} + TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr, + float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr) + : tensor_(nvte_create_tensor(dptr, shape, static_cast(dtype), amax_dptr, + scale_dptr, scale_inv_dptr)) {} /*! \brief Constructs new TensorWrapper. * @@ -214,9 +209,9 @@ class TensorWrapper { */ TensorWrapper(void *dptr, const std::vector &shape, const DType dtype, float *amax_dptr = nullptr, float *scale_dptr = nullptr, - float *scale_inv_dptr = nullptr) : - TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, - amax_dptr, scale_dptr, scale_inv_dptr) {} + float *scale_inv_dptr = nullptr) + : TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr, + scale_inv_dptr) {} /*! \brief Constructs new empty TensorWrapper. * @@ -225,11 +220,9 @@ class TensorWrapper { TensorWrapper() : TensorWrapper(nullptr, std::vector(), DType::kFloat32) {} /*! \brief TensorWrapper destructor. */ - ~TensorWrapper() { - nvte_destroy_tensor(tensor_); - } + ~TensorWrapper() { nvte_destroy_tensor(tensor_); } - TensorWrapper& operator=(const TensorWrapper &other) = delete; + TensorWrapper &operator=(const TensorWrapper &other) = delete; TensorWrapper(const TensorWrapper &other) = delete; /*! \brief Constructs new TensorWrapper from existing TensorWrapper. @@ -249,7 +242,7 @@ class TensorWrapper { * * \param[in,out] other The source of the data. */ - TensorWrapper& operator=(TensorWrapper &&other) { + TensorWrapper &operator=(TensorWrapper &&other) { if (this == &other) return *this; nvte_destroy_tensor(tensor_); tensor_ = other.tensor_; @@ -261,9 +254,7 @@ class TensorWrapper { * * \return NVTETensor held by this TensorWrapper. */ - NVTETensor data() const noexcept { - return tensor_; - } + NVTETensor data() const noexcept { return tensor_; } /*! \brief Get the shape of this TensorWrapper. * diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index 0d55be5d40ff2a1958769b7726ea66a43ba1f58e..ef3d344b059214da6733b8f0b62567060149ad61 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -28,10 +28,8 @@ extern "C" { * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N]. * \param[in] stream CUDA stream used for the operation. */ -void nvte_cast_transpose(const NVTETensor input, - NVTETensor cast_output, - NVTETensor transposed_output, - cudaStream_t stream); +void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output, + NVTETensor transposed_output, cudaStream_t stream); /*! \brief Transpose the input. 
* @@ -39,9 +37,7 @@ void nvte_cast_transpose(const NVTETensor input, * \param[out] transposed_output Result of the transpose. Shape: [H, N]. * \param[in] stream CUDA stream used for the operation. */ -void nvte_transpose(const NVTETensor input, - NVTETensor transposed_output, - cudaStream_t stream); +void nvte_transpose(const NVTETensor input, NVTETensor transposed_output, cudaStream_t stream); /*! \brief Cast and transpose the input. Additionally, reduce the input along the first dimension. * @@ -61,11 +57,8 @@ void nvte_transpose(const NVTETensor input, * \param[out] workspace Workspace tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_cast_transpose_dbias(const NVTETensor input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, +void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output, + NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Transpose the FP8 input. Additionally, reduce the input along the first dimension. @@ -84,11 +77,8 @@ void nvte_cast_transpose_dbias(const NVTETensor input, * \param[out] workspace Workspace tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_fp8_transpose_dbias(const NVTETensor input, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream); +void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Cast and transpose multiple tensors. * @@ -105,10 +95,8 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, * of tensors in input_list. * \param[in] stream CUDA stream used for the operation. */ -void nvte_multi_cast_transpose(size_t num_tensors, - const NVTETensor* input_list, - NVTETensor* cast_output_list, - NVTETensor* transposed_output_list, +void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, + NVTETensor* cast_output_list, NVTETensor* transposed_output_list, cudaStream_t stream); /*! \brief Compute backward of ActLU operation on the input, then cast and transpose. Additionally, @@ -131,50 +119,29 @@ void nvte_multi_cast_transpose(size_t num_tensors, * first dimension. Shape: [H]. * \param[out] workspace Workspace tensor. * \param[in] stream CUDA stream used for the operation. 
- + Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ -void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, - const NVTETensor act_input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream); - -void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, - const NVTETensor act_input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream); - -void nvte_cast_transpose_dbias_drelu(const NVTETensor input, - const NVTETensor act_input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream); - -void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, - const NVTETensor act_input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream); - -void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, - const NVTETensor act_input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream); +void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor cast_output, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + +void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor act_input, + NVTETensor cast_output, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + +void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor cast_output, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor cast_output, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + +void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor cast_output, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dgeglu of the input, additionally does cast and transpose the dgeglu output. * @@ -189,38 +156,28 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N]. * \param[in] stream CUDA stream used for the operation. 
- Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU + Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ -void nvte_dgeglu_cast_transpose(const NVTETensor input, - const NVTETensor act_input, - NVTETensor cast_output, - NVTETensor transposed_output, +void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, + NVTETensor cast_output, NVTETensor transposed_output, cudaStream_t stream); -void nvte_dswiglu_cast_transpose(const NVTETensor input, - const NVTETensor act_input, - NVTETensor cast_output, - NVTETensor transposed_output, - cudaStream_t stream); +void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, + NVTETensor cast_output, NVTETensor transposed_output, + cudaStream_t stream); -void nvte_dreglu_cast_transpose(const NVTETensor input, - const NVTETensor act_input, - NVTETensor cast_output, - NVTETensor transposed_output, +void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, + NVTETensor cast_output, NVTETensor transposed_output, cudaStream_t stream); -void nvte_dqgeglu_cast_transpose(const NVTETensor input, - const NVTETensor act_input, - NVTETensor cast_output, - NVTETensor transposed_output, - cudaStream_t stream); +void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, + NVTETensor cast_output, NVTETensor transposed_output, + cudaStream_t stream); -void nvte_dsreglu_cast_transpose(const NVTETensor input, - const NVTETensor act_input, - NVTETensor cast_output, - NVTETensor transposed_output, - cudaStream_t stream); +void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, + NVTETensor cast_output, NVTETensor transposed_output, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/layer_norm/ln.h b/transformer_engine/common/layer_norm/ln.h index 2e3788e7770cdaf1fa10ff7d38fca06962c56f99..45839ed75b87eaa6b13f57393223384670cc5900 100644 --- a/transformer_engine/common/layer_norm/ln.h +++ b/transformer_engine/common/layer_norm/ln.h @@ -8,11 +8,12 @@ #define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ #include + #include #include #include -#include #include +#include #include "../common.h" @@ -21,113 +22,107 @@ namespace layer_norm { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct LaunchParams{ - size_t workspace_bytes; - size_t barrier_size; +template +struct LaunchParams { + size_t workspace_bytes; + size_t barrier_size; - int multiprocessorCount; - cudaStream_t stream; + int multiprocessorCount; + cudaStream_t stream; - Params params; + Params params; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct ParamsBase { - ParamsBase() - : ctas_per_col(0) - , rows(0) - , cols(0) - , x(nullptr) - , mu(nullptr) - , rs(nullptr) - , gamma(nullptr) - , workspace(nullptr) - , barrier(nullptr) - , zero_centered_gamma(false) {} - - - // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. - int ctas_per_col; - // Size of CTA group. - int ctas_per_row; - - // Input is interpreted as matrix. We normalize across columns. - int rows; - int cols; - - // Common data pointers. - void *x; - void *mu; - void *rs; - void *gamma; - - // Multi-CTA workspace in gmem. - void *workspace; - - // Multi-CTA sync barriers in gmem. 
- int *barrier; - - // Whether gamma is centered around 0 - bool zero_centered_gamma; + ParamsBase() + : ctas_per_col(0), + rows(0), + cols(0), + x(nullptr), + mu(nullptr), + rs(nullptr), + gamma(nullptr), + workspace(nullptr), + barrier(nullptr), + zero_centered_gamma(false) {} + + // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. + int ctas_per_col; + // Size of CTA group. + int ctas_per_row; + + // Input is interpreted as matrix. We normalize across columns. + int rows; + int cols; + + // Common data pointers. + void *x; + void *mu; + void *rs; + void *gamma; + + // Multi-CTA workspace in gmem. + void *workspace; + + // Multi-CTA sync barriers in gmem. + int *barrier; + + // Whether gamma is centered around 0 + bool zero_centered_gamma; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct FwdParams : public ParamsBase { - FwdParams() - : ParamsBase() - , z(nullptr) - , beta(nullptr) - , epsilon(0.f) - , fp8_out(false) {} - - // Output of LN FWD. - void *z; - void *beta; - float epsilon; - - // Scaling factor - void *scale; - - // AMax output - void *amax; - - // Whether to compute scale and amax - bool fp8_out; + FwdParams() : ParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {} + + // Output of LN FWD. + void *z; + void *beta; + float epsilon; + + // Scaling factor + void *scale; + + // AMax output + void *amax; + + // Whether to compute scale and amax + bool fp8_out; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct BwdParams : public ParamsBase { - BwdParams() - : ParamsBase() - , dz(nullptr) - , dbeta_part(nullptr) - , dgamma_part(nullptr) - , dx(nullptr) - , dbeta(nullptr) - , dgamma(nullptr) {} - - // Input: gradient wrt. LN FWD output. - void *dz; - - // Workspace for Wgrad pre-reduction. - void *dbeta_part; - void *dgamma_part; - - // Output: Dgrad. - void *dx; - // Output: Wgrad. - void *dbeta; - void *dgamma; + BwdParams() + : ParamsBase(), + dz(nullptr), + dbeta_part(nullptr), + dgamma_part(nullptr), + dx(nullptr), + dbeta(nullptr), + dgamma(nullptr) {} + + // Input: gradient wrt. LN FWD output. + void *dz; + + // Workspace for Wgrad pre-reduction. + void *dbeta_part; + void *dgamma_part; + + // Output: Dgrad. + void *dx; + // Output: Wgrad. 
+ void *dbeta; + void *dgamma; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -using FwdFunction = std::function&, const bool)>; -using BwdFunction = std::function&, const bool)>; +using FwdFunction = std::function &, const bool)>; +using BwdFunction = std::function &, const bool)>; using FunctionKey = uint64_t; using FwdTunedRegistry = std::unordered_map; using BwdTunedRegistry = std::unordered_map; @@ -141,96 +136,96 @@ extern BwdGeneralRegistry BWD_GENERAL_FUNCS; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct TypeId{}; +template +struct TypeId {}; -template<> -struct TypeId{ - constexpr static uint32_t Value = 0; +template <> +struct TypeId { + constexpr static uint32_t Value = 0; }; -template<> -struct TypeId{ - constexpr static uint32_t Value = 1; +template <> +struct TypeId { + constexpr static uint32_t Value = 1; }; -template<> -struct TypeId{ - constexpr static uint32_t Value = 2; +template <> +struct TypeId { + constexpr static uint32_t Value = 2; }; -template<> -struct TypeId{ - constexpr static uint32_t Value = 3; +template <> +struct TypeId { + constexpr static uint32_t Value = 3; }; -template -struct Type2Key{ - constexpr static uint32_t Value = TypeId::Value << S; +template +struct Type2Key { + constexpr static uint32_t Value = TypeId::Value << S; }; -template -struct WeightType2Key : public Type2Key{}; +template +struct WeightType2Key : public Type2Key {}; -template -struct InputType2Key : public Type2Key{}; +template +struct InputType2Key : public Type2Key {}; -template -struct OutputType2Key : public Type2Key{}; +template +struct OutputType2Key : public Type2Key {}; -template -struct ComputeType2Key : public Type2Key{}; +template +struct ComputeType2Key : public Type2Key {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Types2Key{ - constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | - OutputType2Key::Value | ComputeType2Key::Value; - constexpr static inline uint64_t get(const uint64_t hidden_size){ - constexpr uint64_t type_key = Value; - return (type_key << 32) | hidden_size; - } +template +struct Types2Key { + constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | + OutputType2Key::Value | ComputeType2Key::Value; + constexpr static inline uint64_t get(const uint64_t hidden_size) { + constexpr uint64_t type_key = Value; + return (type_key << 32) | hidden_size; + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct FwdTunedRegistrar{ - explicit FwdTunedRegistrar(FwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - FWD_TUNED_FUNCS.insert({ key, f }); - } +template +struct FwdTunedRegistrar { + explicit FwdTunedRegistrar(FwdFunction f) { + uint64_t key = Types2Key::get(HIDDEN_SIZE); + FWD_TUNED_FUNCS.insert({key, f}); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct FwdGeneralRegistrar{ - explicit FwdGeneralRegistrar(FwdFunction f){ - uint64_t key = Types2Key::get(0); - FWD_GENERAL_FUNCS[key].insert({ HIDDEN_SIZE, f }); - } +template +struct FwdGeneralRegistrar { + explicit FwdGeneralRegistrar(FwdFunction f) { + uint64_t key = Types2Key::get(0); + FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); + } }; 
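For orientation, the registrar structs above key each launcher by packing the four element types and the hidden size into a single 64-bit value. A minimal standalone sketch of that packing, assuming the 2-bit shifts used by get_key() below and the (type_key << 32) | hidden_size composition shown in Types2Key; make_ln_key and the example ids are illustrative helpers, not names from the patch:

#include <cstdint>

// Packs the weight/input/output/compute type ids (2 bits each) into the
// upper 32 bits and the hidden size into the lower 32 bits, mirroring
// Types2Key::get() and layer_norm::get_key().
constexpr uint64_t make_ln_key(uint64_t wid, uint64_t iid, uint64_t oid, uint64_t cid,
                               uint64_t hidden_size) {
  const uint64_t type_key = wid | (iid << 2) | (oid << 4) | (cid << 6);
  return (type_key << 32) | hidden_size;
}

// Tuned kernels are registered under the real hidden size; general kernels are
// registered under hidden_size == 0 and then selected from a per-key map
// ordered by HIDDEN_SIZE (see FWD_GENERAL_FUNCS above).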
//////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct BwdTunedRegistrar{ - explicit BwdTunedRegistrar(BwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - BWD_TUNED_FUNCS.insert({ key, f }); - } +template +struct BwdTunedRegistrar { + explicit BwdTunedRegistrar(BwdFunction f) { + uint64_t key = Types2Key::get(HIDDEN_SIZE); + BWD_TUNED_FUNCS.insert({key, f}); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct BwdGeneralRegistrar{ - explicit BwdGeneralRegistrar(BwdFunction f){ - uint64_t key = Types2Key::get(0); - BWD_GENERAL_FUNCS[key].insert({ HIDDEN_SIZE, f }); - } +template +struct BwdGeneralRegistrar { + explicit BwdGeneralRegistrar(BwdFunction f) { + uint64_t key = Types2Key::get(0); + BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); + } }; ////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp index 7a01cf034538afcd1dad31de15c57c45679b3de2..115422e94ed493fe28095cda648728cb13917f2a 100644 --- a/transformer_engine/common/layer_norm/ln_api.cpp +++ b/transformer_engine/common/layer_norm/ln_api.cpp @@ -9,8 +9,8 @@ #include #include -#include "ln.h" #include "../common.h" +#include "ln.h" /* @@ -46,500 +46,411 @@ BwdGeneralRegistry BWD_GENERAL_FUNCS; //////////////////////////////////////////////////////////////////////////////////////////////////// uint32_t get_type_id(DType dtype) { - if ( dtype == DType::kFloat16 ) { - return TypeId::Value; - } else if ( dtype == DType::kBFloat16 ) { - return TypeId::Value; - } else if ( dtype == DType::kFloat32 ) { - return TypeId::Value; - } else if ( dtype == DType::kFloat8E4M3 ) { - return TypeId::Value; - } else { - NVTE_ERROR("Type not supported."); - } + if (dtype == DType::kFloat16) { + return TypeId::Value; + } else if (dtype == DType::kBFloat16) { + return TypeId::Value; + } else if (dtype == DType::kFloat32) { + return TypeId::Value; + } else if (dtype == DType::kFloat8E4M3) { + return TypeId::Value; + } else { + NVTE_ERROR("Type not supported."); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size) { - using namespace layer_norm; - uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | - (get_type_id(otype) << 4) | (get_type_id(ctype) << 6); - uint64_t launcher_key = (type_key << 32) | hidden_size; - return launcher_key; + using namespace layer_norm; + uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | + (get_type_id(ctype) << 6); + uint64_t launcher_key = (type_key << 32) | hidden_size; + return launcher_key; } //////////////////////////////////////////////////////////////////////////////////////////////////// -layer_norm::FwdFunction & get_fwd_launcher(DType wtype, - DType itype, - DType otype, - DType ctype, - const layer_norm::FwdParams ¶ms) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void *ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 - && is_aligned(params.x) - && is_aligned(params.mu) - && is_aligned(params.rs) - && is_aligned(params.gamma) - && 
is_aligned(params.beta) - && is_aligned(params.z) - && layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) { - return layer_norm::FWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (layer_norm::FWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("FWD: Unsupported types."); - } - auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } +layer_norm::FwdFunction& get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype, + const layer_norm::FwdParams& params) { + // Look for tuned kernel + auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); + auto is_aligned = [](const void* ptr) -> bool { + // Assume vectorized memory accesses are <=16B + return reinterpret_cast(ptr) % 16 == 0; + }; + if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) && + is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.beta) && + is_aligned(params.z) && layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) { + return layer_norm::FWD_TUNED_FUNCS.at(tuned_key); + } + + // Pick general kernel + auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); + if (layer_norm::FWD_GENERAL_FUNCS.count(general_key) == 0) { + NVTE_ERROR("FWD: Unsupported types."); + } + auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key); + auto func_iter = general_func_map.lower_bound(params.cols); + if (func_iter == general_func_map.end()) { + // Hidden size is too big, need to use multi-CTA + return general_func_map.rbegin()->second; + } else { + return func_iter->second; + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -layer_norm::BwdFunction & get_bwd_launcher(DType wtype, - DType itype, - DType otype, - DType ctype, - const layer_norm::BwdParams ¶ms) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void *ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 - && is_aligned(params.x) - && is_aligned(params.mu) - && is_aligned(params.rs) - && is_aligned(params.gamma) - && is_aligned(params.dz) - && is_aligned(params.dx) - && is_aligned(params.dbeta) - && is_aligned(params.dgamma) - && is_aligned(params.dbeta_part) - && is_aligned(params.dgamma_part) - && layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) { - return layer_norm::BWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (layer_norm::BWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("BWD: Unsupported types."); - } - auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } +layer_norm::BwdFunction& get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype, + const layer_norm::BwdParams& params) { + // Look for tuned kernel + auto tuned_key = 
layer_norm::get_key(wtype, itype, otype, ctype, params.cols); + auto is_aligned = [](const void* ptr) -> bool { + // Assume vectorized memory accesses are <=16B + return reinterpret_cast(ptr) % 16 == 0; + }; + if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) && + is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.dz) && + is_aligned(params.dx) && is_aligned(params.dbeta) && is_aligned(params.dgamma) && + is_aligned(params.dbeta_part) && is_aligned(params.dgamma_part) && + layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) { + return layer_norm::BWD_TUNED_FUNCS.at(tuned_key); + } + + // Pick general kernel + auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); + if (layer_norm::BWD_GENERAL_FUNCS.count(general_key) == 0) { + NVTE_ERROR("BWD: Unsupported types."); + } + auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key); + auto func_iter = general_func_map.lower_bound(params.cols); + if (func_iter == general_func_map.end()) { + // Hidden size is too big, need to use multi-CTA + return general_func_map.rbegin()->second; + } else { + return func_iter->second; + } } //////////////////////////////////////////////////////////////////////////////////////////////////// - -size_t product(const std::vector &shape) { - size_t ret = 1; - for (auto s : shape) { - ret *= s; - } - return ret; +size_t product(const std::vector& shape) { + size_t ret = 1; + for (auto s : shape) { + ret *= s; + } + return ret; } } // namespace layer_norm //////////////////////////////////////////////////////////////////////////////////////////////////// -void layernorm_fwd(const Tensor& x, // BxSxhidden_size - const Tensor& gamma, // hidden_size - const Tensor& beta, // hidden_size - const float epsilon, - Tensor* z, - Tensor* mu, - Tensor* rsigma, - cudaStream_t stream, - const int multiprocessorCount, - Tensor* workspace, - Tensor* barrier, +void layernorm_fwd(const Tensor& x, // BxSxhidden_size + const Tensor& gamma, // hidden_size + const Tensor& beta, // hidden_size + const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, cudaStream_t stream, + const int multiprocessorCount, Tensor* workspace, Tensor* barrier, const bool zero_centered_gamma) { - const auto itype = x.data.dtype; - const auto wtype = gamma.data.dtype; - const auto otype = z->data.dtype; - const bool fp8_out = is_fp8_dtype(otype); - const auto ctype = layer_norm::DType::kFloat32; - - NVTE_CHECK(x.data.shape.size() == 2); - - const size_t rows = x.data.shape[0]; - const size_t cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(gamma.data.shape == beta.data.shape); - NVTE_CHECK(hidden_size == cols); - - NVTE_CHECK(epsilon >= 0.f); - - NVTE_CHECK(z->data.shape == x.data.shape); - - NVTE_CHECK(mu->data.shape == std::vector{ rows }); - NVTE_CHECK(mu->data.dtype == ctype); - - NVTE_CHECK(rsigma->data.shape == std::vector{ rows }); - NVTE_CHECK(rsigma->data.dtype == ctype); - - layer_norm::LaunchParams launch_params; - - launch_params.multiprocessorCount = multiprocessorCount; - launch_params.stream = stream; - - // Set the kernel runtime parameters. 
- layer_norm::FwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = mu->data.dptr; - params.rs = rsigma->data.dptr; - params.gamma = gamma.data.dptr; - params.beta = beta.data.dptr; - params.z = z->data.dptr; - params.epsilon = epsilon; - params.amax = z->amax.dptr; - params.scale = z->scale.dptr; - params.fp8_out = fp8_out; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } - - if (workspace->data.dptr == nullptr) { - NVTE_CHECK(barrier->data.dptr == nullptr); - - workspace->data.dtype = layer_norm::DType::kByte; - workspace->data.shape = { launch_params.workspace_bytes }; - - barrier->data.dtype = layer_norm::DType::kInt32; - barrier->data.shape = { launch_params.barrier_size }; - - return; - } else { - NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(x, "x"); - CheckInputTensor(gamma, "gamma"); - CheckInputTensor(beta, "beta"); - - CheckOutputTensor(*z, "z"); - CheckOutputTensor(*mu, "mu"); - CheckOutputTensor(*rsigma, "rsigma"); - - if ( launch_params.barrier_size > 0 ) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - } - - // Clear buffers - if ( params.fp8_out ) { - cudaMemsetAsync(params.amax, 0, - layer_norm::product(z->amax.shape) * - typeToSize(z->amax.dtype), stream); - } - if ( launch_params.barrier_size > 0 ) { - cudaMemsetAsync(params.barrier, 0, - layer_norm::product(barrier->data.shape) * - typeToSize(barrier->data.dtype), stream); - } - - // Launch the kernel. - launcher(launch_params, false); + const auto itype = x.data.dtype; + const auto wtype = gamma.data.dtype; + const auto otype = z->data.dtype; + const bool fp8_out = is_fp8_dtype(otype); + const auto ctype = layer_norm::DType::kFloat32; + + NVTE_CHECK(x.data.shape.size() == 2); + + const size_t rows = x.data.shape[0]; + const size_t cols = x.data.shape[1]; + const auto hidden_size = gamma.data.shape[0]; + + NVTE_CHECK(gamma.data.shape == beta.data.shape); + NVTE_CHECK(hidden_size == cols); + + NVTE_CHECK(epsilon >= 0.f); + + NVTE_CHECK(z->data.shape == x.data.shape); + + NVTE_CHECK(mu->data.shape == std::vector{rows}); + NVTE_CHECK(mu->data.dtype == ctype); + + NVTE_CHECK(rsigma->data.shape == std::vector{rows}); + NVTE_CHECK(rsigma->data.dtype == ctype); + + layer_norm::LaunchParams launch_params; + + launch_params.multiprocessorCount = multiprocessorCount; + launch_params.stream = stream; + + // Set the kernel runtime parameters. 
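// The pointers and sizes recorded in params below are consumed twice: by
// get_fwd_launcher(), which only returns a tuned kernel when rows % 4 == 0 and
// x/mu/rs/gamma/beta/z are 16-byte aligned, and then by the launched kernel.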
+ layer_norm::FwdParams& params = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x = x.data.dptr; + params.mu = mu->data.dptr; + params.rs = rsigma->data.dptr; + params.gamma = gamma.data.dptr; + params.beta = beta.data.dptr; + params.z = z->data.dptr; + params.epsilon = epsilon; + params.amax = z->amax.dptr; + params.scale = z->scale.dptr; + params.fp8_out = fp8_out; + params.zero_centered_gamma = zero_centered_gamma; + + // Request the kernel launcher. + auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params); + + // Query the kernel-specific launch parameters. + launcher(launch_params, true); + if (launch_params.workspace_bytes == 0) { + launch_params.workspace_bytes = 1; + } + + if (workspace->data.dptr == nullptr) { + NVTE_CHECK(barrier->data.dptr == nullptr); + + workspace->data.dtype = layer_norm::DType::kByte; + workspace->data.shape = {launch_params.workspace_bytes}; + + barrier->data.dtype = layer_norm::DType::kInt32; + barrier->data.shape = {launch_params.barrier_size}; return; + } else { + NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); + } + + // Tensor checks are delayed here in order to recover workspace sizes with null data + CheckInputTensor(x, "x"); + CheckInputTensor(gamma, "gamma"); + CheckInputTensor(beta, "beta"); + + CheckOutputTensor(*z, "z"); + CheckOutputTensor(*mu, "mu"); + CheckOutputTensor(*rsigma, "rsigma"); + + if (launch_params.barrier_size > 0) { + params.workspace = workspace->data.dptr; + params.barrier = reinterpret_cast(barrier->data.dptr); + } + + // Clear buffers + if (params.fp8_out) { + cudaMemsetAsync(params.amax, 0, layer_norm::product(z->amax.shape) * typeToSize(z->amax.dtype), + stream); + } + if (launch_params.barrier_size > 0) { + cudaMemsetAsync(params.barrier, 0, + layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), + stream); + } + + // Launch the kernel. 
+ launcher(launch_params, false); + + return; } -void layernorm_bwd(const Tensor& dz, - const Tensor& x, - const Tensor& mu, - const Tensor& rsigma, - const Tensor& gamma, - Tensor* dx, - Tensor* dgamma, - Tensor* dbeta, - Tensor* dgamma_part, - Tensor* dbeta_part, - cudaStream_t stream, - const int multiprocessorCount, - Tensor* workspace, - Tensor* barrier, - const bool zero_centered_gamma -) { - using namespace transformer_engine; - - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - auto otype = wtype; - auto ctype = DType::kFloat32; - - NVTE_CHECK(dz.data.dtype == otype); - NVTE_CHECK(mu.data.dtype == ctype); - NVTE_CHECK(rsigma.data.dtype == ctype); - - NVTE_CHECK(x.data.shape.size() == 2); - NVTE_CHECK(dz.data.shape == x.data.shape); - auto rows = x.data.shape[0]; - auto cols = x.data.shape[1]; - - auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(mu.data.shape[0] == rows); - NVTE_CHECK(mu.data.shape == rsigma.data.shape); - - NVTE_CHECK(gamma.data.shape[0] == cols); - - NVTE_CHECK(dx->data.shape == x.data.shape); - NVTE_CHECK(dx->data.dtype == x.data.dtype); - - NVTE_CHECK(dgamma->data.shape == gamma.data.shape); - NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); - - NVTE_CHECK(dbeta->data.shape == gamma.data.shape); - NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype); - - layer_norm::LaunchParams launch_params; - launch_params.stream = stream; - launch_params.multiprocessorCount = multiprocessorCount; - - // Set the kernel runtime parameters. - layer_norm::BwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = mu.data.dptr; - params.rs = rsigma.data.dptr; - params.gamma = gamma.data.dptr; - params.dz = dz.data.dptr; - params.dx = dx->data.dptr; - params.dbeta = dbeta->data.dptr; - params.dgamma = dgamma->data.dptr; - params.dbeta_part = dbeta_part->data.dptr; - params.dgamma_part = dgamma_part->data.dptr; - params.zero_centered_gamma = zero_centered_gamma; - - auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. 
- launcher(launch_params, true); - - // Populate shape and dtypes for FW to allocate memory - if (dgamma_part->data.dptr == nullptr) { - NVTE_CHECK(dbeta_part->data.dptr == nullptr); - - dgamma_part->data.dtype = ctype; - dgamma_part->data.shape = { static_cast (launch_params.params.ctas_per_col), - hidden_size }; - - dbeta_part->data.dtype = ctype; - dbeta_part->data.shape = { static_cast (launch_params.params.ctas_per_col), - hidden_size }; - - workspace->data.dtype = layer_norm::DType::kByte; - workspace->data.shape = { launch_params.workspace_bytes }; - - barrier->data.dtype = layer_norm::DType::kInt32; - barrier->data.shape = { launch_params.barrier_size }; - - return; - } else { - NVTE_CHECK(dbeta_part->data.dptr != nullptr); - auto pdw_shape = std::vector{ - static_cast(launch_params.params.ctas_per_col), hidden_size}; - - NVTE_CHECK(dgamma_part->data.dtype == ctype); - NVTE_CHECK(dgamma_part->data.shape == pdw_shape); - NVTE_CHECK(dbeta_part->data.dtype == ctype); - NVTE_CHECK(dbeta_part->data.shape == pdw_shape); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); - } - - if (launch_params.workspace_bytes > 0) { - NVTE_CHECK(workspace->data.dptr != nullptr); - NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(dz, "dz"); - CheckInputTensor(x, "x"); - CheckInputTensor(mu, "mu"); - CheckInputTensor(rsigma, "rsigma"); - CheckInputTensor(gamma, "gamma"); - CheckOutputTensor(*dx, "dx"); - CheckOutputTensor(*dgamma, "dgamma"); - CheckOutputTensor(*dbeta, "dbeta"); - - if ( launch_params.barrier_size > 0 ) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - cudaMemsetAsync(params.barrier, 0, - layer_norm::product(barrier->data.shape) * - typeToSize(barrier->data.dtype), stream); - } - - // Launch the kernel. 
- launcher(launch_params, false); +void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma, + const Tensor& gamma, Tensor* dx, Tensor* dgamma, Tensor* dbeta, + Tensor* dgamma_part, Tensor* dbeta_part, cudaStream_t stream, + const int multiprocessorCount, Tensor* workspace, Tensor* barrier, + const bool zero_centered_gamma) { + using namespace transformer_engine; + + auto itype = x.data.dtype; + auto wtype = gamma.data.dtype; + auto otype = wtype; + auto ctype = DType::kFloat32; + + NVTE_CHECK(dz.data.dtype == otype); + NVTE_CHECK(mu.data.dtype == ctype); + NVTE_CHECK(rsigma.data.dtype == ctype); + + NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(dz.data.shape == x.data.shape); + auto rows = x.data.shape[0]; + auto cols = x.data.shape[1]; + + auto hidden_size = gamma.data.shape[0]; + + NVTE_CHECK(mu.data.shape[0] == rows); + NVTE_CHECK(mu.data.shape == rsigma.data.shape); + + NVTE_CHECK(gamma.data.shape[0] == cols); + + NVTE_CHECK(dx->data.shape == x.data.shape); + NVTE_CHECK(dx->data.dtype == x.data.dtype); + + NVTE_CHECK(dgamma->data.shape == gamma.data.shape); + NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); + + NVTE_CHECK(dbeta->data.shape == gamma.data.shape); + NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype); + + layer_norm::LaunchParams launch_params; + launch_params.stream = stream; + launch_params.multiprocessorCount = multiprocessorCount; + + // Set the kernel runtime parameters. + layer_norm::BwdParams& params = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x = x.data.dptr; + params.mu = mu.data.dptr; + params.rs = rsigma.data.dptr; + params.gamma = gamma.data.dptr; + params.dz = dz.data.dptr; + params.dx = dx->data.dptr; + params.dbeta = dbeta->data.dptr; + params.dgamma = dgamma->data.dptr; + params.dbeta_part = dbeta_part->data.dptr; + params.dgamma_part = dgamma_part->data.dptr; + params.zero_centered_gamma = zero_centered_gamma; + + auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params); + + // Query the kernel-specific launch parameters. 
+ launcher(launch_params, true); + + // Populate shape and dtypes for FW to allocate memory + if (dgamma_part->data.dptr == nullptr) { + NVTE_CHECK(dbeta_part->data.dptr == nullptr); + + dgamma_part->data.dtype = ctype; + dgamma_part->data.shape = {static_cast(launch_params.params.ctas_per_col), + hidden_size}; + + dbeta_part->data.dtype = ctype; + dbeta_part->data.shape = {static_cast(launch_params.params.ctas_per_col), + hidden_size}; + + workspace->data.dtype = layer_norm::DType::kByte; + workspace->data.shape = {launch_params.workspace_bytes}; + + barrier->data.dtype = layer_norm::DType::kInt32; + barrier->data.shape = {launch_params.barrier_size}; + + return; + } else { + NVTE_CHECK(dbeta_part->data.dptr != nullptr); + auto pdw_shape = + std::vector{static_cast(launch_params.params.ctas_per_col), hidden_size}; + + NVTE_CHECK(dgamma_part->data.dtype == ctype); + NVTE_CHECK(dgamma_part->data.shape == pdw_shape); + NVTE_CHECK(dbeta_part->data.dtype == ctype); + NVTE_CHECK(dbeta_part->data.shape == pdw_shape); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); + } + + if (launch_params.workspace_bytes > 0) { + NVTE_CHECK(workspace->data.dptr != nullptr); + NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); + } + + // Tensor checks are delayed here in order to recover workspace sizes with null data + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(mu, "mu"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + CheckOutputTensor(*dbeta, "dbeta"); + + if (launch_params.barrier_size > 0) { + params.workspace = workspace->data.dptr; + params.barrier = reinterpret_cast(barrier->data.dptr); + cudaMemsetAsync(params.barrier, 0, + layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), + stream); + } + + // Launch the kernel. 
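// ----------------------------------------------------------------------------
// Illustrative sketch (helper names assumed, not taken from the patch):
// layernorm_bwd is driven in two passes.  The first launcher call above, with
// configure_params == true, only fills in ctas_per_col, workspace_bytes and
// barrier_size and records the shapes of the dgamma_part/dbeta_part buffers;
// the caller allocates those tensors and calls again for the real launch
// below.  DeviceBuffer, device_alloc and run_bwd are hypothetical names used
// only for this example.
#include <cstddef>
#include <cstdint>

struct DeviceBuffer {  // assumption: a minimal device allocation handle
  void *ptr = nullptr;
  size_t bytes = 0;
};

void *device_alloc(size_t bytes);  // assumption: a cudaMalloc-style allocator

template <typename Launcher, typename LaunchParamsT>
void run_bwd(Launcher launcher, LaunchParamsT &launch_params) {
  // Pass 1: query only -- no kernel is launched, sizes are written back.
  launcher(launch_params, /*configure_params=*/true);

  // Allocate exactly what the launcher asked for.
  DeviceBuffer workspace{device_alloc(launch_params.workspace_bytes),
                         launch_params.workspace_bytes};
  DeviceBuffer barrier{device_alloc(launch_params.barrier_size * sizeof(int32_t)),
                       launch_params.barrier_size * sizeof(int32_t)};
  launch_params.params.workspace = workspace.ptr;
  launch_params.params.barrier = static_cast<int *>(barrier.ptr);

  // Pass 2: the actual launch, matching the call that follows this aside.
  launcher(launch_params, /*configure_params=*/false);
}
// ----------------------------------------------------------------------------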
+  launcher(launch_params, false);
 }
 
 }  // namespace transformer_engine
 
-void nvte_layernorm_fwd(const NVTETensor x,      // BxSxhidden_size
-                        const NVTETensor gamma,  // hidden_size
-                        const NVTETensor beta,   // hidden_size
-                        const float epsilon,
-                        NVTETensor z,
-                        NVTETensor mu,
-                        NVTETensor rsigma,
-                        cudaStream_t stream,
-                        const int multiprocessorCount,
-                        NVTETensor workspace,
+void nvte_layernorm_fwd(const NVTETensor x,      // BxSxhidden_size
+                        const NVTETensor gamma,  // hidden_size
+                        const NVTETensor beta,   // hidden_size
+                        const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
+                        cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
                         NVTETensor barrier) {
   NVTE_API_CALL(nvte_layernorm_fwd);
   using namespace transformer_engine;
-  layernorm_fwd(*reinterpret_cast<const Tensor *>(x),
-                *reinterpret_cast<const Tensor *>(gamma),
-                *reinterpret_cast<const Tensor *>(beta),
-                epsilon,
-                reinterpret_cast<Tensor *>(z),
-                reinterpret_cast<Tensor *>(mu),
-                reinterpret_cast<Tensor *>(rsigma),
-                stream,
-                multiprocessorCount,
-                reinterpret_cast<Tensor *>(workspace),
-                reinterpret_cast<Tensor *>(barrier),
-                false);
+  layernorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
+                *reinterpret_cast<const Tensor *>(beta), epsilon, reinterpret_cast<Tensor *>(z),
+                reinterpret_cast<Tensor *>(mu), reinterpret_cast<Tensor *>(rsigma), stream,
+                multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
+                reinterpret_cast<Tensor *>(barrier), false);
 }
 
-void nvte_layernorm_bwd(const NVTETensor dz,      // BxSxhidden_size
-                        const NVTETensor x,       // BxSxhidden_size
-                        const NVTETensor mu,      // BxS, FP32!
-                        const NVTETensor rsigma,  // BxS, FP32!
-                        const NVTETensor gamma,   // hidden_size
-                        NVTETensor dx,
-                        NVTETensor dgamma,
-                        NVTETensor dbeta,
-                        NVTETensor dgamma_part,
-                        NVTETensor dbeta_part,
-                        cudaStream_t stream,
-                        const int multiprocessorCount,
-                        NVTETensor workspace,
-                        NVTETensor barrier) {
+void nvte_layernorm_bwd(const NVTETensor dz,      // BxSxhidden_size
+                        const NVTETensor x,       // BxSxhidden_size
+                        const NVTETensor mu,      // BxS, FP32!
+                        const NVTETensor rsigma,  // BxS, FP32!
+                        const NVTETensor gamma,  // hidden_size
+                        NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part,
+                        NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount,
+                        NVTETensor workspace, NVTETensor barrier) {
   NVTE_API_CALL(nvte_layernorm_bwd);
   using namespace transformer_engine;
-  layernorm_bwd(*reinterpret_cast<const Tensor *>(dz),
-                *reinterpret_cast<const Tensor *>(x),
-                *reinterpret_cast<const Tensor *>(mu),
-                *reinterpret_cast<const Tensor *>(rsigma),
-                *reinterpret_cast<const Tensor *>(gamma),
-                reinterpret_cast<Tensor *>(dx),
-                reinterpret_cast<Tensor *>(dgamma),
-                reinterpret_cast<Tensor *>(dbeta),
-                reinterpret_cast<Tensor *>(dgamma_part),
-                reinterpret_cast<Tensor *>(dbeta_part),
-                stream,
-                multiprocessorCount,
-                reinterpret_cast<Tensor *>(workspace),
-                reinterpret_cast<Tensor *>(barrier),
-                false);
+  layernorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
+                *reinterpret_cast<const Tensor *>(mu), *reinterpret_cast<const Tensor *>(rsigma),
+                *reinterpret_cast<const Tensor *>(gamma), reinterpret_cast<Tensor *>(dx),
+                reinterpret_cast<Tensor *>(dgamma), reinterpret_cast<Tensor *>(dbeta),
+                reinterpret_cast<Tensor *>(dgamma_part), reinterpret_cast<Tensor *>(dbeta_part),
+                stream, multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
+                reinterpret_cast<Tensor *>(barrier), false);
 }
 
-void nvte_layernorm1p_fwd(const NVTETensor x,      // BxSxhidden_size
-                          const NVTETensor gamma,  // hidden_size
-                          const NVTETensor beta,   // hidden_size
-                          const float epsilon,
-                          NVTETensor z,
-                          NVTETensor mu,
-                          NVTETensor rsigma,
-                          cudaStream_t stream,
-                          const int multiprocessorCount,
-                          NVTETensor workspace,
+void nvte_layernorm1p_fwd(const NVTETensor x,      // BxSxhidden_size
+                          const NVTETensor gamma,  // hidden_size
+                          const NVTETensor beta,   // hidden_size
+                          const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
+                          cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
                           NVTETensor barrier) {
   NVTE_API_CALL(nvte_layernorm1p_fwd);
   using namespace transformer_engine;
-  layernorm_fwd(*reinterpret_cast<const Tensor *>(x),
-                *reinterpret_cast<const Tensor *>(gamma),
-                *reinterpret_cast<const Tensor *>(beta),
-                epsilon,
-                reinterpret_cast<Tensor *>(z),
-                reinterpret_cast<Tensor *>(mu),
-                reinterpret_cast<Tensor *>(rsigma),
-                stream,
-                multiprocessorCount,
-                reinterpret_cast<Tensor *>(workspace),
-                reinterpret_cast<Tensor *>(barrier),
-                true);
+  layernorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
+                *reinterpret_cast<const Tensor *>(beta), epsilon, reinterpret_cast<Tensor *>(z),
+                reinterpret_cast<Tensor *>(mu), reinterpret_cast<Tensor *>(rsigma), stream,
+                multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
+                reinterpret_cast<Tensor *>(barrier), true);
 }
 
-void nvte_layernorm1p_bwd(const NVTETensor dz,      // BxSxhidden_size
-                          const NVTETensor x,       // BxSxhidden_size
-                          const NVTETensor mu,      // BxS, FP32!
-                          const NVTETensor rsigma,  // BxS, FP32!
-                          const NVTETensor gamma,   // hidden_size
-                          NVTETensor dx,
-                          NVTETensor dgamma,
-                          NVTETensor dbeta,
-                          NVTETensor dgamma_part,
-                          NVTETensor dbeta_part,
-                          cudaStream_t stream,
-                          const int multiprocessorCount,
-                          NVTETensor workspace,
-                          NVTETensor barrier) {
+void nvte_layernorm1p_bwd(const NVTETensor dz,      // BxSxhidden_size
+                          const NVTETensor x,       // BxSxhidden_size
+                          const NVTETensor mu,      // BxS, FP32!
+                          const NVTETensor rsigma,  // BxS, FP32!
+ const NVTETensor gamma, // hidden_size + NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, + NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream, + const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { NVTE_API_CALL(nvte_layernorm1p_bwd); using namespace transformer_engine; - layernorm_bwd(*reinterpret_cast(dz), - *reinterpret_cast(x), - *reinterpret_cast(mu), - *reinterpret_cast(rsigma), - *reinterpret_cast(gamma), - reinterpret_cast(dx), - reinterpret_cast(dgamma), - reinterpret_cast(dbeta), - reinterpret_cast(dgamma_part), - reinterpret_cast(dbeta_part), - stream, - multiprocessorCount, - reinterpret_cast(workspace), - reinterpret_cast(barrier), - true); + layernorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), + *reinterpret_cast(mu), *reinterpret_cast(rsigma), + *reinterpret_cast(gamma), reinterpret_cast(dx), + reinterpret_cast(dgamma), reinterpret_cast(dbeta), + reinterpret_cast(dgamma_part), reinterpret_cast(dbeta_part), + stream, multiprocessorCount, reinterpret_cast(workspace), + reinterpret_cast(barrier), true); } diff --git a/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh b/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh index 3c2522cc7ddb4aab6b53debc51ef93ed0511036b..dbd002524436a4bfd70d774a3fe9acc4b1ca6396 100644 --- a/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh +++ b/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh @@ -7,605 +7,570 @@ #ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ #define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ -#include "ln.h" #include "../utils.cuh" +#include "ln.h" namespace transformer_engine { namespace layer_norm { using namespace transformer_engine; -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_bwd_tuned_kernel(layer_norm::BwdParams params) { - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { COLS = Ktraits::COLS }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using compute_t = typename Ktraits::compute_t; - using index_t = typename Ktraits::index_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - using Reducer = typename Ktraits::Reducer; - using reduce_t = typename Reducer::Type; - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / Ktraits::WARPS_N; - const index_t warp_n = warp % Ktraits::WARPS_N; - const index_t tid_r = warp_n * THREADS_PER_WARP + lane; - - const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); - - Cvec dzy_sum[LDGS]; - Cvec dz_sum[LDGS]; - - memset(dzy_sum, 0, sizeof(dzy_sum)); - memset(dz_sum, 0, sizeof(dz_sum)); - - compute_t * smem_wgrad = reinterpret_cast(smem_); - char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; - - Reducer reducer(params, bidm, 
bidn, warp_m, warp_n, lane, smem_dgrad); - - Sum sum; - - constexpr float rn = 1.f / static_cast(COLS); - Wvec gamma[LDGS]; - index_t idx = c; - #pragma unroll - for ( int it = 0; it < LDGS; it++ ) { - gamma[it].load_from(params.gamma, idx); - idx += Ktraits::VEC_COLS_PER_LDG; +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel( + layer_norm::BwdParams params) { + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { COLS = Ktraits::COLS }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using compute_t = typename Ktraits::compute_t; + using index_t = typename Ktraits::index_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using Reducer = typename Ktraits::Reducer; + using reduce_t = typename Reducer::Type; + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / Ktraits::WARPS_N; + const index_t warp_n = warp % Ktraits::WARPS_N; + const index_t tid_r = warp_n * THREADS_PER_WARP + lane; + + const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); + + Cvec dzy_sum[LDGS]; + Cvec dz_sum[LDGS]; + + memset(dzy_sum, 0, sizeof(dzy_sum)); + memset(dz_sum, 0, sizeof(dz_sum)); + + compute_t *smem_wgrad = reinterpret_cast(smem_); + char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; + + Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); + + Sum sum; + + constexpr float rn = 1.f / static_cast(COLS); + Wvec gamma[LDGS]; + index_t idx = c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + gamma[it].load_from(params.gamma, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } +// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the +// last blocks with syncthreads! +// grid stride over rows +#pragma unroll 1 + for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { + const compute_t mu_r = static_cast(params.mu)[row]; + const compute_t rs_r = static_cast(params.rs)[row]; + Ivec x[LDGS]; + Ovec dz[LDGS]; + index_t idx = row * Ktraits::VEC_COLS + c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + dz[it].load_from(params.dz, idx); + x[it].load_from(params.x, idx); + idx += Ktraits::VEC_COLS_PER_LDG; } - // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the - // last blocks with syncthreads! 
- // grid stride over rows - #pragma unroll 1 - for ( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - const compute_t mu_r = static_cast(params.mu)[row]; - const compute_t rs_r = static_cast(params.rs)[row]; - Ivec x[LDGS]; - Ovec dz[LDGS]; - index_t idx = row * Ktraits::VEC_COLS + c; - #pragma unroll - for ( int it = 0; it < LDGS; it++ ) { - dz[it].load_from(params.dz, idx); - x[it].load_from(params.x, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - - compute_t dy[LDGS * NUM_ELTS]; - compute_t y[LDGS * NUM_ELTS]; - - compute_t mdy_local = 0.f; - compute_t mdyy_local = 0.f; - #pragma unroll - for ( int it = 0; it < LDGS; it++ ) { - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - const compute_t x_tmp = x[it].data.elt[jt]; - const compute_t y_tmp = rs_r * (x_tmp - mu_r); - const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f; - compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift; - dy_tmp *= compute_t(dz[it].data.elt[jt]); - compute_t dz_tmp = dz[it].data.elt[jt]; - - mdy_local += dy_tmp; - mdyy_local += dy_tmp * y_tmp; - - dy[it * NUM_ELTS + jt] = dy_tmp; - y[it * NUM_ELTS + jt] = y_tmp; - - dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; - dz_sum[it].data.elt[jt] += dz_tmp; - } - } - - reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); - mdy_local = layer_norm::Get<0>::of(result) * rn; - mdyy_local = layer_norm::Get<1>::of(result) * rn; - - Ivec dx[LDGS]; - idx = row * Ktraits::VEC_COLS + c; - #pragma unroll - for ( int it = 0; it < LDGS; it++ ) { - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t dy_tmp = dy[it * NUM_ELTS + jt]; - compute_t y_tmp = y[it * NUM_ELTS + jt]; - compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); - dx[it].data.elt[jt] = dx_tmp; - } - dx[it].store_to(params.dx, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - } // end: grid stride loop - - if ( WARPS_M == 1 ) { - idx = r * Ktraits::VEC_COLS + c; - #pragma unroll - for ( int it = 0; it < LDGS; it++ ) { - dz_sum[it].store_to(params.dbeta_part, idx); - dzy_sum[it].store_to(params.dgamma_part, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - } else { - static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, - "Multiple rows per CTA not supported for Multi-CTA."); - // Finalize reduction of part dgamma and dbeta for this CTA - // by reducing over the rows held across the WARPS_M warps - - // Assumption: blockSize divides hidden size. 
- enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; - static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); - - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for ( int it = 0; it < LDGS; it++ ) { - dz_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dz_sum[NUM_RES]; - memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES); - for ( int it = 0; it < ROWS_PER_CTA; it++ ) { - for ( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - __syncthreads(); - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for ( int it = 0; it < LDGS; it++ ) { - dzy_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dzy_sum[NUM_RES]; - memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); - for ( int it = 0; it < ROWS_PER_CTA; it++ ) { - for ( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - - compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * COLS + tidx; - for ( int jt = 0; jt < NUM_RES; jt++ ) { - *dgamma_part = cta_dzy_sum[jt]; - dgamma_part += Ktraits::THREADS_PER_CTA; - } + compute_t dy[LDGS * NUM_ELTS]; + compute_t y[LDGS * NUM_ELTS]; + + compute_t mdy_local = 0.f; + compute_t mdyy_local = 0.f; +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + const compute_t x_tmp = x[it].data.elt[jt]; + const compute_t y_tmp = rs_r * (x_tmp - mu_r); + const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f; + compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift; + dy_tmp *= compute_t(dz[it].data.elt[jt]); + compute_t dz_tmp = dz[it].data.elt[jt]; + + mdy_local += dy_tmp; + mdyy_local += dy_tmp * y_tmp; + + dy[it * NUM_ELTS + jt] = dy_tmp; + y[it * NUM_ELTS + jt] = y_tmp; + + dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; + dz_sum[it].data.elt[jt] += dz_tmp; + } + } - compute_t *dbeta_part = static_cast(params.dbeta_part) + bidm * COLS + tidx; - for ( int jt = 0; jt < NUM_RES; jt++ ) { - *dbeta_part = cta_dz_sum[jt]; - dbeta_part += Ktraits::THREADS_PER_CTA; - } + reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); + mdy_local = layer_norm::Get<0>::of(result) * rn; + mdyy_local = layer_norm::Get<1>::of(result) * rn; + + Ivec dx[LDGS]; + idx = row * Ktraits::VEC_COLS + c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t dy_tmp = dy[it * NUM_ELTS + jt]; + compute_t y_tmp = y[it * NUM_ELTS + jt]; + compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); + dx[it].data.elt[jt] = dx_tmp; + } + dx[it].store_to(params.dx, idx); + idx += Ktraits::VEC_COLS_PER_LDG; } -} + } // end: grid stride loop + + if (WARPS_M == 1) { + idx = r * Ktraits::VEC_COLS + c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + dz_sum[it].store_to(params.dbeta_part, idx); + dzy_sum[it].store_to(params.dgamma_part, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + } else { + static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, + "Multiple rows per CTA not supported for Multi-CTA."); + // Finalize reduction of part dgamma and dbeta for this CTA + // by reducing over the rows held across the WARPS_M warps + + // Assumption: blockSize divides hidden size. 
+ enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; + static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); + + idx = warp_m * Ktraits::VEC_COLS + tid_r; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + dz_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + compute_t cta_dz_sum[NUM_RES]; + memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES); + for (int it = 0; it < ROWS_PER_CTA; it++) { + for (int jt = 0; jt < NUM_RES; jt++) { + cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + __syncthreads(); -template -__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) -void ln_bwd_finalize_tuned_kernel(BwdParams params) { - using compute_t = typename Kernel_traits::compute_t; - using weight_t = typename Kernel_traits::weight_t; - using index_t = typename Kernel_traits::index_t; - using Reducer = typename Kernel_traits::Reducer; - using reduce_t = typename Reducer::Type; - - Sum sum; - enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; - - __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; - - constexpr uint32_t bidm = 0; - - const uint32_t bidn = blockIdx.x; - const uint32_t tidx = threadIdx.x; - const uint32_t warp = tidx / THREADS_PER_WARP; - const uint32_t lane = tidx % THREADS_PER_WARP; - - Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); - - const uint32_t c = bidn * THREADS_PER_WARP + lane; - const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; - constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; - for ( uint32_t col = c, col_out = c_out; - col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) { - // Each thread sums over NUM_ELT columns. - Vec dbeta_local, dgamma_local; - memset(&dgamma_local, 0, sizeof(dgamma_local)); - memset(&dbeta_local, 0, sizeof(dbeta_local)); - for ( uint32_t row = warp; row < params.ctas_per_col; - row += Kernel_traits::ROWS_PER_CTA ) { - index_t idx = row * Kernel_traits::COLS + col; - - Vec dbeta_part, dgamma_part; - dbeta_part.load_from(params.dbeta_part, idx); - dgamma_part.load_from(params.dgamma_part, idx); - #pragma unroll - for ( int it = 0; it < NUM_ELT; it++ ) { - dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; - dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; - } - } + idx = warp_m * Ktraits::VEC_COLS + tid_r; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + dzy_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + compute_t cta_dzy_sum[NUM_RES]; + memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); + for (int it = 0; it < ROWS_PER_CTA; it++) { + for (int jt = 0; jt < NUM_RES; jt++) { + cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } - void * smem_gamma = smem_; - void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; - - const int write_row = warp; - const int write_col = lane ^ write_row; - const int write_idx = write_row * THREADS_PER_WARP + write_col; - - dgamma_local.store_to(smem_gamma, write_idx); - dbeta_local.store_to(smem_beta, write_idx); - - __syncthreads(); - - // It would be probably safe to reuse the first row of smem_beta and smem_gamma - void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; - void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE - + Kernel_traits::SMEM_BYTES_OUTPUT]; - - // More than one iter iff ROWS_PER_CTA < 32. 
- for ( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) { - const int read_row = lane; - const int read_col = w ^ read_row; - const int read_idx = read_row * THREADS_PER_WARP + read_col; - - memset(&dbeta_local, 0, sizeof(dbeta_local)); - memset(&dgamma_local, 0, sizeof(dgamma_local)); - - // Load beta and gamma transposed - if (read_row < Kernel_traits::ROWS_PER_CTA) { - dbeta_local.load_from(smem_beta, read_idx); - dgamma_local.load_from(smem_gamma, read_idx); - } - - // Call reducer on the loaded value(s) and convert. - #pragma unroll - for ( int it = 0; it < NUM_ELT; it++ ) { - compute_t b_i = dbeta_local.data.elt[it]; - compute_t g_i = dgamma_local.data.elt[it]; - b_i = reducer.allreduce(b_i, sum); - g_i = reducer.allreduce(g_i, sum); - - dgamma_local.data.elt[it] = g_i; - dbeta_local.data.elt[it] = b_i; - } - - // Leader stores the result at the current column. - if (lane == 0) { - dgamma_local.store_to(smem_gamma_out, w); - dbeta_local.store_to(smem_beta_out, w); - } - } + compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * COLS + tidx; + for (int jt = 0; jt < NUM_RES; jt++) { + *dgamma_part = cta_dzy_sum[jt]; + dgamma_part += Ktraits::THREADS_PER_CTA; + } - // All writes done. - __syncthreads(); - - // Pack and store: 2-wide stores with half the threads. - if ( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) { - using src_t = typename TypeToVec2::Type; - using dst_t = typename TypeToVec2::Type; - Vec dbeta_vec2, dgamma_vec2; - Vec dbeta_out2, dgamma_out2; - - dgamma_vec2.load_from(smem_gamma_out, lane); - dbeta_vec2.load_from(smem_beta_out, lane); - #pragma unroll - for ( int it = 0; it < NUM_ELT; it++ ) { - dgamma_out2.data.elt[it] = - Converter::convert(dgamma_vec2.data.elt[it]); - dbeta_out2.data.elt[it] = - Converter::convert(dbeta_vec2.data.elt[it]); - } - dgamma_out2.store_to(params.dgamma, col_out); - dbeta_out2.store_to(params.dbeta, col_out); - } + compute_t *dbeta_part = static_cast(params.dbeta_part) + bidm * COLS + tidx; + for (int jt = 0; jt < NUM_RES; jt++) { + *dbeta_part = cta_dz_sum[jt]; + dbeta_part += Ktraits::THREADS_PER_CTA; } + } } -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_bwd_general_kernel(layer_norm::BwdParams params) { - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { WARPS_N = Ktraits::WARPS_N }; - - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using compute_t = typename Ktraits::compute_t; - using output_t = typename Ktraits::output_t; - using index_t = typename Ktraits::index_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - - const index_t tidx = threadIdx.x; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / WARPS_N; - const index_t warp_n = warp % WARPS_N; - - const index_t bdimm = WARPS_M; - const index_t bdimn = WARPS_N * THREADS_PER_WARP; - const index_t bidm = blockIdx.x / params.ctas_per_row; - const index_t bidn = blockIdx.x % params.ctas_per_row; - - const index_t gdimm = bdimm * params.ctas_per_col; - const index_t gdimn = bdimn * params.ctas_per_row; - const index_t gidm = bidm * bdimm + warp_m; - const index_t gidn = (bidn * THREADS_PER_WARP - + warp_n * params.ctas_per_row * THREADS_PER_WARP - + lane); // Order threads by warp x cta x lane - 
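// ----------------------------------------------------------------------------
// Illustrative sketch (names assumed, not from the patch): the per-row math
// implemented by both the tuned and the general backward kernels, written as
// plain scalar C++ for one row.  mu and rs (1/sigma) come from the forward
// pass; zero_centered_gamma adds 1 to gamma exactly as the kernels do via
// dy_tmp_shift / g_ij_shift.
#include <cstddef>
#include <vector>

void layernorm_bwd_row_reference(const std::vector<float> &x, const std::vector<float> &dz,
                                 const std::vector<float> &gamma, float mu, float rs,
                                 bool zero_centered_gamma, std::vector<float> *dx,
                                 std::vector<float> *dgamma_acc, std::vector<float> *dbeta_acc) {
  const size_t cols = x.size();
  const float shift = zero_centered_gamma ? 1.f : 0.f;
  // First pass over the row: form y and dy and the two row means the kernels reduce.
  std::vector<float> y(cols), dy(cols);
  float mdy = 0.f, mdyy = 0.f;
  for (size_t j = 0; j < cols; ++j) {
    y[j] = rs * (x[j] - mu);             // normalized input
    dy[j] = (gamma[j] + shift) * dz[j];  // gradient w.r.t. y
    mdy += dy[j];
    mdyy += dy[j] * y[j];
    (*dgamma_acc)[j] += dz[j] * y[j];    // accumulated across rows (dzy_sum)
    (*dbeta_acc)[j] += dz[j];            // accumulated across rows (dz_sum)
  }
  mdy /= static_cast<float>(cols);
  mdyy /= static_cast<float>(cols);
  // Second pass: dx = rs * (dy - (mdyy * y + mdy)), matching the kernels.
  for (size_t j = 0; j < cols; ++j) {
    (*dx)[j] = rs * (dy[j] - (mdyy * y[j] + mdy));
  }
}
// ----------------------------------------------------------------------------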
- // Objects for weight grads - Cvec dzy_sum[LDGS]; - Cvec dz_sum[LDGS]; - memset(dzy_sum, 0, sizeof(dzy_sum)); - memset(dz_sum, 0, sizeof(dz_sum)); - - // Objects for stats reductions - using reduce_t = typename Ktraits::Reducer::Type; - using Reducer = DynamicReducer; - constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1; - __shared__ char smem_[SMEM_BYTES]; - Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_); - Sum sum; - const compute_t rn = 1.f / static_cast(params.cols); - - // Load weights - Cvec gamma[LDGS]; - #pragma unroll - for ( int it = 0, col = gidn * NUM_ELTS; - it < LDGS && col < params.cols; - it++, col += gdimn * NUM_ELTS ) { - Wvec gamma_in; - gamma_in.load_from_elts(params.gamma, col, params.cols - col); - gamma_in.to(gamma[it]); +template +__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_tuned_kernel( + BwdParams params) { + using compute_t = typename Kernel_traits::compute_t; + using weight_t = typename Kernel_traits::weight_t; + using index_t = typename Kernel_traits::index_t; + using Reducer = typename Kernel_traits::Reducer; + using reduce_t = typename Reducer::Type; + + Sum sum; + enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; + + __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; + + constexpr uint32_t bidm = 0; + + const uint32_t bidn = blockIdx.x; + const uint32_t tidx = threadIdx.x; + const uint32_t warp = tidx / THREADS_PER_WARP; + const uint32_t lane = tidx % THREADS_PER_WARP; + + Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); + + const uint32_t c = bidn * THREADS_PER_WARP + lane; + const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; + constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; + for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; + col += COL_STRIDE, col_out += COL_STRIDE / 2) { + // Each thread sums over NUM_ELT columns. + Vec dbeta_local, dgamma_local; + memset(&dgamma_local, 0, sizeof(dgamma_local)); + memset(&dbeta_local, 0, sizeof(dbeta_local)); + for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) { + index_t idx = row * Kernel_traits::COLS + col; + + Vec dbeta_part, dgamma_part; + dbeta_part.load_from(params.dbeta_part, idx); + dgamma_part.load_from(params.dgamma_part, idx); +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; + dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; + } } - for ( int cta_row = bidm * bdimm; - cta_row < params.rows; - cta_row += gdimm ) { - const int row = cta_row + warp_m; - compute_t mu = 0.f; - compute_t rs = 0.f; - if ( row < params.rows ) { - mu = static_cast(params.mu)[row]; - rs = static_cast(params.rs)[row]; - } + void *smem_gamma = smem_; + void *smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; - Cvec dy[LDGS]; - Cvec y[LDGS]; - compute_t mdy = 0.f; - compute_t mdyy = 0.f; - - #pragma unroll - for ( int it = 0, col = gidn * NUM_ELTS; - it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS ) { - Ivec x; - Ovec dz; - x.load_from_elts(params.x, row * params.cols + col, params.cols - col); - dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col); - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - const compute_t x_ij = x.data.elt[jt]; - const compute_t y_ij = rs * (x_ij - mu); - const compute_t g_ij_shift = (params.zero_centered_gamma) ? 
1.0f : 0.f; - const compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift; - const compute_t dz_ij = dz.data.elt[jt]; - const compute_t dy_ij = g_ij * dz_ij; - - y[it].data.elt[jt] = y_ij; - dy[it].data.elt[jt] = dy_ij; - - mdy += dy_ij; - mdyy += dy_ij * y_ij; - - dz_sum[it].data.elt[jt] += dz_ij; - dzy_sum[it].data.elt[jt] += dz_ij * y_ij; - } - } + const int write_row = warp; + const int write_col = lane ^ write_row; + const int write_idx = write_row * THREADS_PER_WARP + write_col; - // Reduce over row - reduce_t result = reducer.allreduce({mdy, mdyy}, sum); - mdy = layer_norm::Get<0>::of(result) * rn; - mdyy = layer_norm::Get<1>::of(result) * rn; - - // Compute dx - #pragma unroll - for ( int it = 0, col = gidn * NUM_ELTS; - it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS ) { - Ivec dx; - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t dy_ij = dy[it].data.elt[jt]; - compute_t y_ij = y[it].data.elt[jt]; - dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij + mdy)); - } - dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col); - } + dgamma_local.store_to(smem_gamma, write_idx); + dbeta_local.store_to(smem_beta, write_idx); + + __syncthreads(); + + // It would be probably safe to reuse the first row of smem_beta and smem_gamma + void *smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; + void *smem_beta_out = + &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; + + // More than one iter iff ROWS_PER_CTA < 32. + for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) { + const int read_row = lane; + const int read_col = w ^ read_row; + const int read_idx = read_row * THREADS_PER_WARP + read_col; + + memset(&dbeta_local, 0, sizeof(dbeta_local)); + memset(&dgamma_local, 0, sizeof(dgamma_local)); + + // Load beta and gamma transposed + if (read_row < Kernel_traits::ROWS_PER_CTA) { + dbeta_local.load_from(smem_beta, read_idx); + dgamma_local.load_from(smem_gamma, read_idx); + } + +// Call reducer on the loaded value(s) and convert. +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + compute_t b_i = dbeta_local.data.elt[it]; + compute_t g_i = dgamma_local.data.elt[it]; + b_i = reducer.allreduce(b_i, sum); + g_i = reducer.allreduce(g_i, sum); + + dgamma_local.data.elt[it] = g_i; + dbeta_local.data.elt[it] = b_i; + } + + // Leader stores the result at the current column. 
+ if (lane == 0) { + dgamma_local.store_to(smem_gamma_out, w); + dbeta_local.store_to(smem_beta_out, w); + } } - if constexpr ( WARPS_M == 1 ) { - // Write out local weight grad contributions - #pragma unroll - for ( int it = 0, col = gidn * NUM_ELTS; - it < LDGS && col < params.cols; - it++, col += gdimn * NUM_ELTS ) { - dz_sum[it].store_to_elts(params.dbeta_part, - bidm * params.cols + col, - params.cols - col); - dzy_sum[it].store_to_elts(params.dgamma_part, - bidm * params.cols + col, - params.cols - col); - } - } else { - // Reduce weight grad contributions within CTA before writing - __shared__ Cvec vecs_shared[LDGS][WARPS_M][WARPS_N][THREADS_PER_WARP+1]; - - // Reduce dz - #pragma unroll - for ( int it = 0, col = gidn * NUM_ELTS; - it < LDGS && col < params.cols; - it++, col += gdimn * NUM_ELTS ) { - dz_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]); - } - __syncthreads(); - #pragma unroll - for ( int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; - it < LDGS && col < params.cols; - it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS ) { - #pragma unroll - for ( int kt = 0; kt < WARPS_M; kt++ ) { - if ( kt != warp_m ) { - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - dz_sum[it].data.elt[jt] - += vecs_shared[it][kt][warp_n][lane].data.elt[jt]; - } - } - } - dz_sum[it].store_to_elts(params.dbeta_part, - bidm * params.cols + col, - params.cols - col); - } + // All writes done. + __syncthreads(); - // Reduce dzy - __syncthreads(); - #pragma unroll - for ( int it = 0, col = gidn * NUM_ELTS; - it < LDGS && col < params.cols; - it++, col += gdimn * NUM_ELTS ) { - if ( it != warp_m ) { - dzy_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]); - } - } - __syncthreads(); - #pragma unroll - for ( int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; - it < LDGS && col < params.cols; - it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS ) { - #pragma unroll - for ( int kt = 0; kt < WARPS_M; kt++ ) { - if ( kt != warp_m ) { - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - dzy_sum[it].data.elt[jt] - += vecs_shared[it][kt][warp_n][lane].data.elt[jt]; - } - } - } - dzy_sum[it].store_to_elts(params.dgamma_part, - bidm * params.cols + col, - params.cols - col); - } + // Pack and store: 2-wide stores with half the threads. 
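// ----------------------------------------------------------------------------
// Illustrative sketch (an aside on the shared-memory transpose above, before
// the pack-and-store step below): the finalize kernel stages each warp's
// partial dgamma/dbeta with an XOR swizzle (write_col = lane ^ write_row,
// read_col = w ^ read_row), which keeps both the write and the transposed
// read spread across distinct shared-memory banks.  A tiny host-side model of
// that indexing for an assumed 32x32 tile:
#include <cassert>

inline void check_xor_swizzle() {
  int tile[32][32];
  // Warp w, lane l writes its value v(w, l) to slot [w][l ^ w].
  for (int w = 0; w < 32; ++w)
    for (int l = 0; l < 32; ++l) tile[w][l ^ w] = w * 32 + l;
  // In transpose iteration w, lane l reads slot [l][w ^ l], which holds v(l, w),
  // i.e. the tile is read back transposed without a second buffer.
  for (int w = 0; w < 32; ++w)
    for (int l = 0; l < 32; ++l) assert(tile[l][w ^ l] == l * 32 + w);
}
// ----------------------------------------------------------------------------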
+ if (warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2) { + using src_t = typename TypeToVec2::Type; + using dst_t = typename TypeToVec2::Type; + Vec dbeta_vec2, dgamma_vec2; + Vec dbeta_out2, dgamma_out2; + + dgamma_vec2.load_from(smem_gamma_out, lane); + dbeta_vec2.load_from(smem_beta_out, lane); +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + dgamma_out2.data.elt[it] = Converter::convert(dgamma_vec2.data.elt[it]); + dbeta_out2.data.elt[it] = Converter::convert(dbeta_vec2.data.elt[it]); + } + dgamma_out2.store_to(params.dgamma, col_out); + dbeta_out2.store_to(params.dbeta, col_out); } + } } -template< - typename weight_t, - typename compute_t, - uint32_t WARPS_M, - uint32_t WARPS_N, - uint32_t BYTES_PER_LDG, - uint32_t THREADS_PER_WARP -> -__global__ __launch_bounds__(WARPS_M * WARPS_N * THREADS_PER_WARP) -void ln_bwd_finalize_general_kernel(layer_norm::BwdParams params) { - enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; - using Wvec = Vec; - using Cvec = Vec; - - const int lane = threadIdx.x % THREADS_PER_WARP; - const int warp_m = threadIdx.y; - const int warp_n = threadIdx.x / THREADS_PER_WARP; - const int col = blockIdx.x * blockDim.x + threadIdx.x; - - // Load grad contributions and accumulate locally - Cvec dgamma, dbeta; - dgamma.clear(); - dbeta.clear(); - for ( int row = warp_m; - row < params.ctas_per_col && col < params.cols; - row += WARPS_M ) { - Cvec dgamma_part, dbeta_part; - dgamma_part.load_from_elts(params.dgamma_part, - row * params.cols + col, - params.cols - col); - dbeta_part.load_from_elts(params.dbeta_part, - row * params.cols + col, - params.cols - col); - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - dgamma.data.elt[jt] += dgamma_part.data.elt[jt]; - dbeta.data.elt[jt] += dbeta_part.data.elt[jt]; +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kernel( + layer_norm::BwdParams params) { + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { WARPS_N = Ktraits::WARPS_N }; + + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using compute_t = typename Ktraits::compute_t; + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + + const index_t tidx = threadIdx.x; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t bdimm = WARPS_M; + const index_t bdimn = WARPS_N * THREADS_PER_WARP; + const index_t bidm = blockIdx.x / params.ctas_per_row; + const index_t bidn = blockIdx.x % params.ctas_per_row; + + const index_t gdimm = bdimm * params.ctas_per_col; + const index_t gdimn = bdimn * params.ctas_per_row; + const index_t gidm = bidm * bdimm + warp_m; + const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP + + lane); // Order threads by warp x cta x lane + + // Objects for weight grads + Cvec dzy_sum[LDGS]; + Cvec dz_sum[LDGS]; + memset(dzy_sum, 0, sizeof(dzy_sum)); + memset(dz_sum, 0, sizeof(dz_sum)); + + // Objects for stats reductions + using reduce_t = typename Ktraits::Reducer::Type; + using Reducer = DynamicReducer; + constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? 
Reducer::SMEM_BYTES : 1; + __shared__ char smem_[SMEM_BYTES]; + Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_); + Sum sum; + const compute_t rn = 1.f / static_cast(params.cols); + + // Load weights + Cvec gamma[LDGS]; +#pragma unroll + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + Wvec gamma_in; + gamma_in.load_from_elts(params.gamma, col, params.cols - col); + gamma_in.to(gamma[it]); + } + + for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { + const int row = cta_row + warp_m; + compute_t mu = 0.f; + compute_t rs = 0.f; + if (row < params.rows) { + mu = static_cast(params.mu)[row]; + rs = static_cast(params.rs)[row]; + } + + Cvec dy[LDGS]; + Cvec y[LDGS]; + compute_t mdy = 0.f; + compute_t mdyy = 0.f; + +#pragma unroll + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + Ivec x; + Ovec dz; + x.load_from_elts(params.x, row * params.cols + col, params.cols - col); + dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col); +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + const compute_t x_ij = x.data.elt[jt]; + const compute_t y_ij = rs * (x_ij - mu); + const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f; + const compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift; + const compute_t dz_ij = dz.data.elt[jt]; + const compute_t dy_ij = g_ij * dz_ij; + + y[it].data.elt[jt] = y_ij; + dy[it].data.elt[jt] = dy_ij; + + mdy += dy_ij; + mdyy += dy_ij * y_ij; + + dz_sum[it].data.elt[jt] += dz_ij; + dzy_sum[it].data.elt[jt] += dz_ij * y_ij; + } + } + + // Reduce over row + reduce_t result = reducer.allreduce({mdy, mdyy}, sum); + mdy = layer_norm::Get<0>::of(result) * rn; + mdyy = layer_norm::Get<1>::of(result) * rn; + +// Compute dx +#pragma unroll + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + Ivec dx; +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t dy_ij = dy[it].data.elt[jt]; + compute_t y_ij = y[it].data.elt[jt]; + dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij + mdy)); + } + dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col); + } + } + + if constexpr (WARPS_M == 1) { +// Write out local weight grad contributions +#pragma unroll + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + dz_sum[it].store_to_elts(params.dbeta_part, bidm * params.cols + col, params.cols - col); + dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, params.cols - col); + } + } else { + // Reduce weight grad contributions within CTA before writing + __shared__ Cvec vecs_shared[LDGS][WARPS_M][WARPS_N][THREADS_PER_WARP + 1]; + +// Reduce dz +#pragma unroll + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + dz_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]); + } + __syncthreads(); +#pragma unroll + for (int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; it < LDGS && col < params.cols; + it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS) { +#pragma unroll + for (int kt = 0; kt < WARPS_M; kt++) { + if (kt != warp_m) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + dz_sum[it].data.elt[jt] += vecs_shared[it][kt][warp_n][lane].data.elt[jt]; + } } + } + dz_sum[it].store_to_elts(params.dbeta_part, bidm * params.cols + col, 
params.cols - col); } - // Reduce dgamma within CTA - __shared__ Cvec vecs_shared[WARPS_M][WARPS_N][THREADS_PER_WARP+1]; - dgamma.store_to(&vecs_shared[warp_m][warp_n][lane]); - #pragma unroll - for ( int nrows = WARPS_M / 2; nrows > 0; nrows /= 2 ) { - __syncthreads(); - if ( warp_m < nrows ) { - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - vecs_shared[warp_m][warp_n][lane].data.elt[jt] - += vecs_shared[warp_m+nrows][warp_n][lane].data.elt[jt]; - } + // Reduce dzy + __syncthreads(); +#pragma unroll + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + if (it != warp_m) { + dzy_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]); + } + } + __syncthreads(); +#pragma unroll + for (int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; it < LDGS && col < params.cols; + it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS) { +#pragma unroll + for (int kt = 0; kt < WARPS_M; kt++) { + if (kt != warp_m) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + dzy_sum[it].data.elt[jt] += vecs_shared[it][kt][warp_n][lane].data.elt[jt]; + } } + } + dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, params.cols - col); } - if ( warp_m == 0 && col < params.cols ) { - Wvec dgamma_out; - vecs_shared[warp_m][warp_n][lane].to(dgamma_out); - dgamma_out.store_to_elts(params.dgamma, col, params.cols - col); + } +} + +template +__global__ +__launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void ln_bwd_finalize_general_kernel( + layer_norm::BwdParams params) { + enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; + using Wvec = Vec; + using Cvec = Vec; + + const int lane = threadIdx.x % THREADS_PER_WARP; + const int warp_m = threadIdx.y; + const int warp_n = threadIdx.x / THREADS_PER_WARP; + const int col = blockIdx.x * blockDim.x + threadIdx.x; + + // Load grad contributions and accumulate locally + Cvec dgamma, dbeta; + dgamma.clear(); + dbeta.clear(); + for (int row = warp_m; row < params.ctas_per_col && col < params.cols; row += WARPS_M) { + Cvec dgamma_part, dbeta_part; + dgamma_part.load_from_elts(params.dgamma_part, row * params.cols + col, params.cols - col); + dbeta_part.load_from_elts(params.dbeta_part, row * params.cols + col, params.cols - col); +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + dgamma.data.elt[jt] += dgamma_part.data.elt[jt]; + dbeta.data.elt[jt] += dbeta_part.data.elt[jt]; } + } - // Reduce dgamma within CTA + // Reduce dgamma within CTA + __shared__ Cvec vecs_shared[WARPS_M][WARPS_N][THREADS_PER_WARP + 1]; + dgamma.store_to(&vecs_shared[warp_m][warp_n][lane]); +#pragma unroll + for (int nrows = WARPS_M / 2; nrows > 0; nrows /= 2) { __syncthreads(); - dbeta.store_to(&vecs_shared[warp_m][warp_n][lane]); - #pragma unroll - for ( int nrows = WARPS_M / 2; nrows > 0; nrows /= 2 ) { - __syncthreads(); - if ( warp_m < nrows ) { - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - vecs_shared[warp_m][warp_n][lane].data.elt[jt] - += vecs_shared[warp_m+nrows][warp_n][lane].data.elt[jt]; - } - } + if (warp_m < nrows) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + vecs_shared[warp_m][warp_n][lane].data.elt[jt] += + vecs_shared[warp_m + nrows][warp_n][lane].data.elt[jt]; + } } - if ( warp_m == 0 && col < params.cols ) { - Wvec dbeta_out; - vecs_shared[warp_m][warp_n][lane].to(dbeta_out); - dbeta_out.store_to_elts(params.dbeta, col, params.cols - col); + } + if (warp_m == 0 && col < params.cols) { + Wvec dgamma_out; + 
vecs_shared[warp_m][warp_n][lane].to(dgamma_out); + dgamma_out.store_to_elts(params.dgamma, col, params.cols - col); + } + + // Reduce dgamma within CTA + __syncthreads(); + dbeta.store_to(&vecs_shared[warp_m][warp_n][lane]); +#pragma unroll + for (int nrows = WARPS_M / 2; nrows > 0; nrows /= 2) { + __syncthreads(); + if (warp_m < nrows) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + vecs_shared[warp_m][warp_n][lane].data.elt[jt] += + vecs_shared[warp_m + nrows][warp_n][lane].data.elt[jt]; + } } + } + if (warp_m == 0 && col < params.cols) { + Wvec dbeta_out; + vecs_shared[warp_m][warp_n][lane].to(dbeta_out); + dbeta_out.store_to_elts(params.dbeta, col, params.cols - col); + } } } // namespace layer_norm diff --git a/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu index aaf5b69010b1b9879ae05c7b86db230576540068..17f12569104f5b6d2ec4619e2e152708016f9127 100644 --- a/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -5,233 +5,154 @@ ************************************************************************/ #include "ln.h" -#include "ln_kernel_traits.h" #include "ln_bwd_kernels.cuh" +#include "ln_kernel_traits.h" using namespace transformer_engine::layer_norm; -template< - typename weight_t, - typename input_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG_MAIN, - int BYTES_PER_LDG_FINAL -> -void launch_tuned_(LaunchParams &launch_params, const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &ln_bwd_tuned_kernel; - - if ( configure_params ) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = launch_params.multiprocessorCount - * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::reduce_t) - * 2; - } - return; - } - - if ( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES)); +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &ln_bwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + 
Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::reduce_t) * 2; } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if ( ctas_per_row == 1 ) { - kernel<<>> - (launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), - grid, - block, - reinterpret_cast(¶ms_), - Kernel_traits::SMEM_BYTES, stream); - } - - using Kernel_traits_f = layer_norm::Kernel_traits_finalize; - - auto kernel_f = &layer_norm::ln_bwd_finalize_tuned_kernel; - kernel_f<<>> - (launch_params.params); -} - -template< - typename weight_t, - typename input_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG_MAIN, - int BYTES_PER_LDG_FINAL -> -void launch_general_(LaunchParams &launch_params, const bool configure_params) { // NOLINT(*) - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Instantiate kernel - using Kernel_traits = Kernel_traits; - auto kernel = &ln_bwd_general_kernel; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if ( configure_params ) { - int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), - max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = (ctas_per_col - * WARPS_M - * ctas_per_row - * sizeof(typename Kernel_traits::reduce_t) - * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; + return; + } + + if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); - if ( ctas_per_row == 1 ) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), - grid, - block, - reinterpret_cast(¶ms_), - 0, - stream); - } + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, + stream); + } + + using Kernel_traits_f = layer_norm::Kernel_traits_finalize; + + auto kernel_f = &layer_norm::ln_bwd_finalize_tuned_kernel; + kernel_f<<>>( + launch_params.params); +} - // Launch finalization kernel - constexpr uint32_t WARPS_M_FINAL = 4; - constexpr uint32_t 
WARPS_N_FINAL = 1; - constexpr uint32_t ELTS_N_PER_CTA_FINAL = (Kernel_traits::THREADS_PER_WARP - * WARPS_N_FINAL - * BYTES_PER_LDG_FINAL - / sizeof(compute_t)); - auto kernel_final = &ln_bwd_finalize_general_kernel; - dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); - dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); - kernel_final<<>>(launch_params.params); +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Instantiate kernel + using Kernel_traits = Kernel_traits; + auto kernel = &ln_bwd_general_kernel; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, + Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_size = 2 * ctas_per_col; + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); + } + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } + + // Launch finalization kernel + constexpr uint32_t WARPS_M_FINAL = 4; + constexpr uint32_t WARPS_N_FINAL = 1; + constexpr uint32_t ELTS_N_PER_CTA_FINAL = + (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); + auto kernel_final = + &ln_bwd_finalize_general_kernel; + dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); + dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); + kernel_final<<>>(launch_params.params); } //////////////////////////////////////////////////////////////////////////////////////////////////// -#define REGISTER_BWD_TUNED_LAUNCHER( \ - HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, \ - BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams \ - &launch_params, \ - const bool configure_params) { \ - launch_tuned_(launch_params, configure_params); \ - } \ - static BwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_BWD_GENERAL_LAUNCHER( \ - HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG, \ - BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams \ - &launch_params, \ - const bool configure_params) { \ - launch_general_(launch_params, 
configure_params); \ - } \ - static BwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) +#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ + WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_tuned_(launch_params, \ + configure_params); \ + } \ + static BwdTunedRegistrar \ + reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) + +#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ + BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_general_(launch_params, configure_params); \ + } \ + static BwdGeneralRegistrar \ + reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -252,9 +173,9 @@ REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); @@ -263,11 +184,11 @@ REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_TUNED_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); +REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); +REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_TUNED_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); @@ -318,16 +239,16 @@ REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, 
fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); +REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); +REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); +REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); +REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); @@ -336,9 +257,9 @@ REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); diff --git a/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu old mode 100755 new mode 100644 index da2919bbfc5c67b0bc3da8bcada5a798dcae2dda..0c85f4aeb7269cd23ff3dd111584614d4aebd058 --- a/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu @@ -5,176 +5,128 @@ ************************************************************************/ #include "ln.h" -#include "ln_kernel_traits.h" #include "ln_fwd_kernels.cuh" +#include "ln_kernel_traits.h" using namespace transformer_engine::layer_norm; -template< - typename weight_t, - typename input_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG -> -void launch_tuned_(LaunchParams &launch_params, const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &ln_fwd_tuned_kernel; - - if ( configure_params ) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = launch_params.multiprocessorCount * - ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - 
launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::Stats::stats_t) - * 2; - } - return; - } - - if ( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if ( ctas_per_row == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &ln_fwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::Stats::stats_t) * 2; } -} - -template< - typename weight_t, - typename input_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG -> -void launch_general_(LaunchParams &launch_params, const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &ln_fwd_general_kernel; - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if ( configure_params ) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), - max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = (ctas_per_col - * WARPS_M - * ctas_per_row - * sizeof(compute_t) - * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; + return; + } + + if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { + 
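// Illustrative sketch (not part of the diff): the configure_params path of the general launcher
// above sizes the grid from occupancy -- enough CTAs per row to cover the hidden dimension, and
// enough CTAs per column to cover the rows without exceeding what the GPU can keep resident.
// toy_ln_kernel and the constants below are hypothetical.
#include <algorithm>
#include <cuda_runtime.h>

__global__ void toy_ln_kernel(float *x, int rows, int cols) { /* ... */ }

void configure_grid(int rows, int cols, int hidden_size, int sm_count,
                    int *ctas_per_row, int *ctas_per_col) {
  auto ceil_div = [](int x, int y) { return (x + y - 1) / y; };
  constexpr int threads_per_cta = 128;
  constexpr int warps_m = 4;  // rows handled per CTA in this toy setup
  int ctas_per_sm = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, toy_ln_kernel,
                                                threads_per_cta, /*dynamicSMemSize=*/0);
  const int max_ctas = sm_count * ctas_per_sm;
  *ctas_per_row = ceil_div(cols, hidden_size);        // split one row across CTAs if needed
  *ctas_per_col = std::min(ceil_div(rows, warps_m),   // enough CTAs to cover all rows...
                           max_ctas / *ctas_per_row); // ...but no more than fit on the device
}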
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); - if ( ctas_per_row == 1 ) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), - grid, - block, - reinterpret_cast(¶ms_), - 0, - stream); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) + Kernel_traits::SMEM_BYTES_FWD, stream); + } +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &ln_fwd_general_kernel; + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_size = 2 * ctas_per_col; + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); } + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ - CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_tuned_( \ - launch_params, configure_params); \ - } \ - static FwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ - WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_general_( \ - launch_params, configure_params); \ - } \ - static FwdGeneralRegistrar \ - 
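// Illustrative sketch (not part of the diff): a kernel that needs more than 48 KB of dynamic
// shared memory must opt in before launch, which is what the cudaFuncSetAttribute call above
// does. Minimal stand-alone example with a hypothetical kernel:
#include <cuda_runtime.h>

__global__ void big_smem_kernel(float *out) {
  extern __shared__ float smem[];
  smem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = smem[blockDim.x - 1 - threadIdx.x];
}

cudaError_t launch_with_large_smem(float *out, cudaStream_t stream) {
  const int smem_bytes = 64 * 1024;  // above the default 48 KB per-block limit
  cudaError_t err = cudaFuncSetAttribute(big_smem_kernel,
                                         cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
  if (err != cudaSuccess) return err;
  big_smem_kernel<<<1, 256, smem_bytes, stream>>>(out);
  return cudaGetLastError();
}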
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) +#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ + WARPS_M, WARPS_N, BYTES_PER_LDG) \ + void ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_tuned_(launch_params, configure_params); \ + } \ + static FwdTunedRegistrar \ + reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) + +#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ + BYTES_PER_LDG) \ + void ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_general_(launch_params, configure_params); \ + } \ + static FwdGeneralRegistrar \ + reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -187,21 +139,21 @@ REGISTER_FWD_TUNED_LAUNCHER(1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_FWD_TUNED_LAUNCHER(2304, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 4); REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(5120, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(6144, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(10240, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); REGISTER_FWD_TUNED_LAUNCHER(16384, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(18432, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(20480, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 4); REGISTER_FWD_TUNED_LAUNCHER(32768, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(40960, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(49152, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); @@ -213,21 +165,21 @@ REGISTER_FWD_TUNED_LAUNCHER(1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_FWD_TUNED_LAUNCHER(2304, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); 
REGISTER_FWD_TUNED_LAUNCHER(3072, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 4); REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(18432, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(20480, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(24576, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 4); REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(40960, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(49152, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); @@ -239,21 +191,21 @@ REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 4); REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, 
fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 4); REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); @@ -295,11 +247,11 @@ REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, bf16, fp32, 1, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp16, fp32, 1, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, bf16, fp32, 1, 1, 4, 4); REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); @@ -337,17 +289,17 @@ REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, bf16, fp32, 2, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp16, fp32, 2, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, bf16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, bf16, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp16, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, bf16, fp32, 2, 1, 4, 8); REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); @@ -373,17 +325,17 @@ REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, bf16, fp32, 2, 1, 4, 16); 
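// Illustrative sketch (not part of the diff): the registration tables above let the host code
// select a tuned launcher by hidden size at run time and fall back to a general launcher when
// no exact specialization exists. A standalone, hypothetical version of that lookup:
#include <cstdint>
#include <functional>
#include <map>

using ToyLauncher = std::function<void(int, int)>;  // (rows, cols)

void dispatch_layernorm(const std::map<uint32_t, ToyLauncher> &tuned, const ToyLauncher &general,
                        uint32_t hidden_size, int rows, int cols) {
  auto it = tuned.find(hidden_size);
  if (it != tuned.end()) {
    it->second(rows, cols);  // exact hidden-size match: use the tuned specialization
  } else {
    general(rows, cols);     // otherwise fall back to the general kernel
  }
}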
-REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, bf16, fp32, 4, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp16, fp32, 4, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, bf16, fp32, 4, 1, 4, 4); + +REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp16, fp32, 4, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); +REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, bf16, fp32, 4, 1, 4, 4); REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); diff --git a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh b/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh index 5b3943fc1a2e1ce0d31fe32ea45da334e7d23906..9fe4c16373debb9cdc103f258ab57d5facab853d 100644 --- a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh @@ -9,306 +9,294 @@ #include #include -#include "ln.h" + #include "../utils.cuh" +#include "ln.h" namespace transformer_engine { namespace layer_norm { using namespace transformer_engine; -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_fwd_tuned_kernel(FwdParams params) { - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::NUM_ELTS }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using output_t = typename Ktraits::output_t; - using index_t = typename Ktraits::index_t; - using compute_t = typename Ktraits::compute_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - - using Stats = typename Ktraits::Stats; - using stats_t = typename Stats::stats_t; - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / WARPS_N; - const index_t warp_n = warp % WARPS_N; - - const index_t r = bidm * ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - Stats stats(params, bidm, bidn, warp_m, warp_n, 
lane, smem_); - - compute_t *mu_ptr = static_cast(params.mu); - compute_t *rs_ptr = static_cast(params.rs); - - Wvec gamma[LDGS]; - Wvec beta[LDGS]; - index_t idx = c; - #pragma unroll - for ( int it = 0; it < LDGS; ++it ) { - gamma[it].load_from(params.gamma, idx); - beta[it].load_from(params.beta, idx); - idx += VEC_COLS_PER_LDG; +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(FwdParams params) { + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::NUM_ELTS }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using compute_t = typename Ktraits::compute_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + + using Stats = typename Ktraits::Stats; + using stats_t = typename Stats::stats_t; + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t r = bidm * ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); + + compute_t *mu_ptr = static_cast(params.mu); + compute_t *rs_ptr = static_cast(params.rs); + + Wvec gamma[LDGS]; + Wvec beta[LDGS]; + index_t idx = c; +#pragma unroll + for (int it = 0; it < LDGS; ++it) { + gamma[it].load_from(params.gamma, idx); + beta[it].load_from(params.beta, idx); + idx += VEC_COLS_PER_LDG; + } + + constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); + + compute_t scale = 1.f; + if (params.fp8_out) { + scale = *reinterpret_cast(params.scale); + } + compute_t amax = 0; + + for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { + Ivec x[LDGS]; + index_t idx = row * Ktraits::VEC_COLS + c; + compute_t xf[LDGS * NUM_ELTS]; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + x[it].load_from(params.x, idx); +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t x_ij = compute_t(x[it].data.elt[jt]); + xf[it * NUM_ELTS + jt] = x_ij; + } + idx += VEC_COLS_PER_LDG; } - constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); + stats_t s = stats.compute(xf, rn); + + compute_t mu = layer_norm::Get<0>::of(s); + compute_t m2 = layer_norm::Get<1>::of(s); - compute_t scale = 1.f; - if (params.fp8_out) { - scale = *reinterpret_cast(params.scale); + if (bidn == 0 && warp_n == 0 && lane == 0) { + mu_ptr[row] = mu; } - compute_t amax = 0; - - for ( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - Ivec x[LDGS]; - index_t idx = row * Ktraits::VEC_COLS + c; - compute_t xf[LDGS * NUM_ELTS]; - #pragma unroll - for ( int it = 0; it < LDGS; it++ ) { - x[it].load_from(params.x, idx); - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t x_ij = compute_t(x[it].data.elt[jt]); - xf[it * NUM_ELTS + jt] = x_ij; - } - idx += VEC_COLS_PER_LDG; - } - stats_t s = stats.compute(xf, rn); 
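// Illustrative sketch (not part of the diff): per row, the forward kernel above computes the mean
// mu and the reciprocal standard deviation rs, stores both, and writes
// z = gamma * (x - mu) * rs + beta (with gamma shifted by 1 when zero_centered_gamma is set).
// A single-threaded host reference of that math, handy for checking the kernels; all names are
// hypothetical and output vectors are assumed preallocated.
#include <cmath>
#include <vector>

void layernorm_reference(const std::vector<float> &x, const std::vector<float> &gamma,
                         const std::vector<float> &beta, int rows, int cols, float epsilon,
                         bool zero_centered_gamma, std::vector<float> *z,
                         std::vector<float> *mu_out, std::vector<float> *rs_out) {
  for (int r = 0; r < rows; ++r) {
    float mu = 0.f;
    for (int c = 0; c < cols; ++c) mu += x[r * cols + c];
    mu /= cols;
    float var = 0.f;
    for (int c = 0; c < cols; ++c) {
      const float d = x[r * cols + c] - mu;
      var += d * d;
    }
    var /= cols;
    const float rs = 1.f / std::sqrt(var + epsilon);  // matches rsqrtf(rn * m2 + epsilon)
    (*mu_out)[r] = mu;
    (*rs_out)[r] = rs;
    for (int c = 0; c < cols; ++c) {
      const float g = gamma[c] + (zero_centered_gamma ? 1.f : 0.f);
      (*z)[r * cols + c] = g * rs * (x[r * cols + c] - mu) + beta[c];
    }
  }
}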
+ compute_t rs = rsqrtf(rn * m2 + params.epsilon); - compute_t mu = layer_norm::Get<0>::of(s); - compute_t m2 = layer_norm::Get<1>::of(s); + if (bidn == 0 && warp_n == 0 && lane == 0) { + rs_ptr[row] = rs; + } - if ( bidn == 0 && warp_n == 0 && lane == 0 ) { - mu_ptr[row] = mu; + Ovec z[LDGS]; + idx = row * Ktraits::VEC_COLS + c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t y_ij = rs * (xf[it * NUM_ELTS + jt] - mu); + compute_t g_ij = gamma[it].data.elt[jt]; + if (params.zero_centered_gamma) { + g_ij += 1; } + compute_t b_ij = beta[it].data.elt[jt]; + compute_t temp_output = g_ij * y_ij + b_ij; - compute_t rs = rsqrtf(rn * m2 + params.epsilon); - - if ( bidn == 0 && warp_n == 0 && lane == 0 ) { - rs_ptr[row] = rs; + if (params.fp8_out) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(temp_output)); + temp_output = temp_output * scale; } - Ovec z[LDGS]; - idx = row * Ktraits::VEC_COLS + c; - #pragma unroll - for ( int it = 0; it < LDGS; it++ ) { - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t y_ij = rs * (xf[it * NUM_ELTS + jt] - mu); - compute_t g_ij = gamma[it].data.elt[jt]; - if (params.zero_centered_gamma) { - g_ij += 1; - } - compute_t b_ij = beta[it].data.elt[jt]; - compute_t temp_output = g_ij * y_ij + b_ij; - - if (params.fp8_out) { - __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(temp_output)); - temp_output = temp_output * scale; - } - - z[it].data.elt[jt] = output_t(temp_output); - } - z[it].store_to(params.z, idx); - idx += VEC_COLS_PER_LDG; - } + z[it].data.elt[jt] = output_t(temp_output); + } + z[it].store_to(params.z, idx); + idx += VEC_COLS_PER_LDG; } - if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0 && threadIdx.y == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); - } + } + if (params.fp8_out) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0 && threadIdx.y == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); } + } } -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_fwd_general_kernel(FwdParams params) { - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::NUM_ELTS }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { WARPS_N = Ktraits::WARPS_N }; - - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using output_t = typename Ktraits::output_t; - using index_t = typename Ktraits::index_t; - using compute_t = typename Ktraits::compute_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - - const index_t tidx = threadIdx.x; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / WARPS_N; - const index_t warp_n = warp % WARPS_N; - - const index_t bdimm = WARPS_M; - const index_t bdimn = WARPS_N * THREADS_PER_WARP; - const index_t bidm = blockIdx.x / params.ctas_per_row; - const index_t bidn = blockIdx.x % params.ctas_per_row; - - const index_t gdimm = bdimm * params.ctas_per_col; - const index_t gdimn = bdimn * params.ctas_per_row; - const index_t gidm = bidm * bdimm + warp_m; - const index_t gidn = (bidn * THREADS_PER_WARP - + warp_n * params.ctas_per_row * THREADS_PER_WARP - + lane); // Order threads by warp x cta x lane - - // Objects for stats reductions - using Reducer 
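// Illustrative sketch (not part of the diff): for FP8 output, each thread above tracks the
// largest |value| it produced, the partial maxima are reduced, and one thread publishes the
// result with an atomic max so the scaling recipe can use it later. The helper atomicMaxFloat
// below is a hypothetical stand-in for the one used in this code base; it relies on amax and the
// stored value being non-negative.
#include <cuda_runtime.h>

__device__ float atomicMaxFloat(float *addr, float value) {
  // For non-negative floats the IEEE-754 bit pattern is monotone, so an integer atomicMax works.
  const int old = atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
  return __int_as_float(old);
}

__global__ void scale_to_fp8_range(const float *in, float *out, float *amax_ptr,
                                   const float *scale_ptr, int n) {
  const float scale = *scale_ptr;
  float amax = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x) {
    const float v = in[i];
    amax = fmaxf(amax, fabsf(v));
    out[i] = v * scale;  // a real kernel would also convert to an FP8 type here
  }
  // Warp-level reduction of the running maximum; one lane per warp publishes it.
  for (int offset = 16; offset > 0; offset /= 2) {
    amax = fmaxf(amax, __shfl_down_sync(0xffffffff, amax, offset));
  }
  if ((threadIdx.x % 32) == 0) {
    atomicMaxFloat(amax_ptr, amax);
  }
}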
= DynamicReducer; - constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1; - __shared__ char smem_[SMEM_BYTES]; - Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_); - Sum sum; - const compute_t rn = 1.f / static_cast(params.cols); - - // Load weights - Cvec gamma[LDGS]; - Cvec beta[LDGS]; - #pragma unroll - for ( int it = 0, col = gidn * NUM_ELTS; - it < LDGS && col < params.cols; - ++it, col += gdimn * NUM_ELTS ) { - Wvec gamma_in, beta_in; - gamma_in.load_from_elts(params.gamma, col, params.cols - col); - beta_in.load_from_elts(params.beta, col, params.cols - col); - gamma_in.to(gamma[it]); - beta_in.to(beta[it]); +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kernel( + FwdParams params) { + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::NUM_ELTS }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { WARPS_N = Ktraits::WARPS_N }; + + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using compute_t = typename Ktraits::compute_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + + const index_t tidx = threadIdx.x; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t bdimm = WARPS_M; + const index_t bdimn = WARPS_N * THREADS_PER_WARP; + const index_t bidm = blockIdx.x / params.ctas_per_row; + const index_t bidn = blockIdx.x % params.ctas_per_row; + + const index_t gdimm = bdimm * params.ctas_per_col; + const index_t gdimn = bdimn * params.ctas_per_row; + const index_t gidm = bidm * bdimm + warp_m; + const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP + + lane); // Order threads by warp x cta x lane + + // Objects for stats reductions + using Reducer = DynamicReducer; + constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? 
Reducer::SMEM_BYTES : 1; + __shared__ char smem_[SMEM_BYTES]; + Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_); + Sum sum; + const compute_t rn = 1.f / static_cast(params.cols); + + // Load weights + Cvec gamma[LDGS]; + Cvec beta[LDGS]; +#pragma unroll + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; + ++it, col += gdimn * NUM_ELTS) { + Wvec gamma_in, beta_in; + gamma_in.load_from_elts(params.gamma, col, params.cols - col); + beta_in.load_from_elts(params.beta, col, params.cols - col); + gamma_in.to(gamma[it]); + beta_in.to(beta[it]); + } + + // fp8 factors + compute_t scale; + if (params.fp8_out) { + scale = *reinterpret_cast(params.scale); + } + compute_t amax = 0; + + for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { + const int row = cta_row + warp_m; + + // Load input + Cvec x[LDGS]; +#pragma unroll + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + Ivec x_in; + x_in.load_from_elts(params.x, row * params.cols + col, params.cols - col); + x_in.to(x[it]); } - // fp8 factors - compute_t scale; - if ( params.fp8_out ) { - scale = *reinterpret_cast(params.scale); + // Compute mean + compute_t mu = 0.f; +#pragma unroll + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; + it++, col += gdimn * NUM_ELTS) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + mu += x[it].data.elt[jt]; + } } - compute_t amax = 0; - - for ( int cta_row = bidm * bdimm; - cta_row < params.rows; - cta_row += gdimm ) { - const int row = cta_row + warp_m; - - // Load input - Cvec x[LDGS]; - #pragma unroll - for ( int it = 0, col = gidn * NUM_ELTS; - it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS ) { - Ivec x_in; - x_in.load_from_elts(params.x, - row * params.cols + col, - params.cols - col); - x_in.to(x[it]); + mu = reducer.allreduce(mu, sum) * rn; + + // Compute variance + compute_t sqsigma = 0.f; +#pragma unroll + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; + it++, col += gdimn * NUM_ELTS) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + if (col + jt < params.cols) { + compute_t diff = x[it].data.elt[jt] - mu; + sqsigma += diff * diff; } + } + } + sqsigma = reducer.allreduce(sqsigma, sum) * rn; + compute_t rs = rsqrtf(sqsigma + params.epsilon); + + // Write statistics + if (gidn == 0 && row < params.rows) { + compute_t *mu_ptr = static_cast(params.mu); + compute_t *rs_ptr = static_cast(params.rs); + mu_ptr[row] = mu; + rs_ptr[row] = rs; + } - // Compute mean - compute_t mu = 0.f; - #pragma unroll - for ( int it = 0, col = gidn * NUM_ELTS; - it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS ) { - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - mu += x[it].data.elt[jt]; - } - } - mu = reducer.allreduce(mu, sum) * rn; - - // Compute variance - compute_t sqsigma = 0.f; - #pragma unroll - for ( int it = 0, col = gidn * NUM_ELTS; - it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS ) { - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - if ( col + jt < params.cols ) { - compute_t diff = x[it].data.elt[jt] - mu; - sqsigma += diff * diff; - } - } +// Compute output +#pragma unroll + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + // Compute output values + 
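// Illustrative sketch (not part of the diff): reducer.allreduce(mu, sum) above makes the
// per-thread partial sums visible to every participating thread. A minimal warp-level version of
// that idea using shuffle intrinsics (the real DynamicReducer also spans warps and CTAs through
// shared and global memory):
__device__ float warp_allreduce_sum(float v) {
  for (int offset = 16; offset > 0; offset /= 2) {
    v += __shfl_xor_sync(0xffffffff, v, offset);  // butterfly exchange: all lanes end with the sum
  }
  return v;
}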
Cvec z; +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t y_ij = rs * (x[it].data.elt[jt] - mu); + compute_t g_ij = gamma[it].data.elt[jt]; + if (params.zero_centered_gamma) { + g_ij += 1; } - sqsigma = reducer.allreduce(sqsigma, sum) * rn; - compute_t rs = rsqrtf(sqsigma + params.epsilon); - - // Write statistics - if ( gidn == 0 && row < params.rows ) { - compute_t *mu_ptr = static_cast(params.mu); - compute_t *rs_ptr = static_cast(params.rs); - mu_ptr[row] = mu; - rs_ptr[row] = rs; + compute_t b_ij = beta[it].data.elt[jt]; + z.data.elt[jt] = g_ij * y_ij + b_ij; + } + + // Apply fp8 factors + if (params.fp8_out) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + if (col + jt < params.cols) { + compute_t z_ij = z.data.elt[jt]; + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(z_ij)); + z.data.elt[jt] = z_ij * scale; + } } + } - // Compute output - #pragma unroll - for ( int it = 0, col = gidn * NUM_ELTS; - it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS ) { - // Compute output values - Cvec z; - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t y_ij = rs * (x[it].data.elt[jt] - mu); - compute_t g_ij = gamma[it].data.elt[jt]; - if (params.zero_centered_gamma) { - g_ij += 1; - } - compute_t b_ij = beta[it].data.elt[jt]; - z.data.elt[jt] = g_ij * y_ij + b_ij; - } - - // Apply fp8 factors - if ( params.fp8_out ) { - #pragma unroll - for ( int jt = 0; jt < NUM_ELTS; jt++ ) { - if ( col + jt < params.cols ) { - compute_t z_ij = z.data.elt[jt]; - __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(z_ij)); - z.data.elt[jt] = z_ij * scale; - } - } - } - - // Store output - Ovec z_out; - z.to(z_out); - z_out.store_to_elts(params.z, - row * params.cols + col, - params.cols - col); - } + // Store output + Ovec z_out; + z.to(z_out); + z_out.store_to_elts(params.z, row * params.cols + col, params.cols - col); } - - // Finalize fp8 factors - if ( params.fp8_out ) { - amax = reduce_max(amax, warp); - if ( threadIdx.x == 0 ) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); - } + } + + // Finalize fp8 factors + if (params.fp8_out) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); } + } } } // namespace layer_norm diff --git a/transformer_engine/common/layer_norm/ln_kernel_traits.h b/transformer_engine/common/layer_norm/ln_kernel_traits.h index 3aa985e6c60e45d28f60a24f4728b5fba553a23d..a72726c325b0a50f226c580474599a6c6bc37649 100644 --- a/transformer_engine/common/layer_norm/ln_kernel_traits.h +++ b/transformer_engine/common/layer_norm/ln_kernel_traits.h @@ -14,154 +14,119 @@ namespace transformer_engine { namespace layer_norm { -template< - uint32_t HIDDEN_SIZE_, - typename weight_t_, - typename input_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t THREADS_PER_CTA_ -> +template struct Kernel_traits_base { - using weight_t = weight_t_; - using input_t = input_t_; - using output_t = output_t_; - using compute_t = compute_t_; - using index_t = index_t_; - - enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; - enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; - enum { THREADS_PER_WARP = 32 }; + using weight_t = weight_t_; + using input_t = input_t_; + using output_t = output_t_; + using compute_t = compute_t_; + using index_t = index_t_; + + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; + enum { THREADS_PER_WARP 
= 32 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template< - uint32_t HIDDEN_SIZE_, - typename weight_t_, - typename input_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t THREADS_PER_CTA_, - uint32_t BYTES_PER_LDG_, - typename Base = Kernel_traits_base -> +template > struct Kernel_traits_finalize : public Base { - enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; - static_assert(static_cast(ROWS_PER_CTA) <= static_cast(Base::THREADS_PER_WARP)); - // Bytes per global load from the input. - enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; - // Number of elements fetched by a global load. - enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; - // Bytes per global store of the weights. - enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; - static_assert(sizeof(BYTES_PER_LDG) == 4, - "Conflict-free smem transpose only implemented for 4B compute type!"); - static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, - "We assume one warp per row!"); - // The total number of BYTES_PER_LDG-wide words in a hidden vector. - enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; - static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); - - // Shared memory size to transpose the CTA result. - enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; - // Shared memory size to coalsece the CTA result. - enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; - // Shared memory requirement per CTA. - enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT }; - - // The type of the reducer. - using Reducer = transformer_engine::Reducer; - - // Condition for the whole CTA to participate in syncthreads. - static_assert(COLS % Base::THREADS_PER_WARP == 0); - enum { CTAS = COLS / Base::THREADS_PER_WARP }; + enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; + static_assert(static_cast(ROWS_PER_CTA) <= static_cast(Base::THREADS_PER_WARP)); + // Bytes per global load from the input. + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + // Number of elements fetched by a global load. + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; + // Bytes per global store of the weights. + enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; + static_assert(sizeof(BYTES_PER_LDG) == 4, + "Conflict-free smem transpose only implemented for 4B compute type!"); + static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, + "We assume one warp per row!"); + // The total number of BYTES_PER_LDG-wide words in a hidden vector. + enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; + static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); + + // Shared memory size to transpose the CTA result. + enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; + // Shared memory size to coalsece the CTA result. + enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; + // Shared memory requirement per CTA. + enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT }; + + // The type of the reducer. + using Reducer = transformer_engine::Reducer; + + // Condition for the whole CTA to participate in syncthreads. 
+ static_assert(COLS % Base::THREADS_PER_WARP == 0); + enum { CTAS = COLS / Base::THREADS_PER_WARP }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - typename weight_t_, - typename input_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t HIDDEN_SIZE_, - uint32_t CTAS_PER_ROW_, - uint32_t WARPS_M_, - uint32_t WARPS_N_, - uint32_t BYTES_PER_LDG_ = 16, - typename Base = Kernel_traits_base< - HIDDEN_SIZE_, - weight_t_, - input_t_, - output_t_, - compute_t_, - index_t_, - WARPS_M_*WARPS_N_*THREADS_PER_WARP - > -> +template > struct Kernel_traits : public Base { - using input_t = typename Base::input_t; - using weight_t = typename Base::weight_t; - using compute_t = typename Base::compute_t; - using output_t = typename Base::output_t; - using index_t = typename Base::index_t; - - enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; - enum { WARPS_M = WARPS_M_ }; - enum { WARPS_N = WARPS_N_ }; - enum { COLS = HIDDEN_SIZE_ }; - enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; - enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; - enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; - - enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; - enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; - enum { ROWS_PER_CTA = WARPS_M }; - - enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; - enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; - // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed - enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) }; - static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); - - using reduce_t = typename transformer_engine::TypeToVec2::Type; - using Reducer = transformer_engine::Reducer; - - enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; - enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; - - using Ivec = transformer_engine::Vec; - using Ovec = transformer_engine::Vec; - using Wvec = transformer_engine::Vec; - using Cvec = transformer_engine::Vec; - enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; - - // Assume that each thread can handle the same number of elements - // in the output and weights as in the input. - static_assert(sizeof(input_t) >= sizeof(output_t)); - static_assert(sizeof(input_t) >= sizeof(weight_t)); - // The number of columns fetched per load from input: one per thread. - enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; - // The total number of vectorized loads/stores per hidden vector. - enum { VEC_COLS = COLS / ELTS_PER_LDG }; - // The number of loads per thread for the input. 
- enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; - static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); - // static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); - - using Stats = transformer_engine::Stats; - enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; + using input_t = typename Base::input_t; + using weight_t = typename Base::weight_t; + using compute_t = typename Base::compute_t; + using output_t = typename Base::output_t; + using index_t = typename Base::index_t; + + enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; + enum { WARPS_M = WARPS_M_ }; + enum { WARPS_N = WARPS_N_ }; + enum { COLS = HIDDEN_SIZE_ }; + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; + + enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; + enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; + enum { ROWS_PER_CTA = WARPS_M }; + + enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; + enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; + // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed + enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA* COLS * sizeof(compute_t) }; + static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); + + using reduce_t = typename transformer_engine::TypeToVec2::Type; + using Reducer = transformer_engine::Reducer; + + enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; + enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; + + using Ivec = transformer_engine::Vec; + using Ovec = transformer_engine::Vec; + using Wvec = transformer_engine::Vec; + using Cvec = transformer_engine::Vec; + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; + + // Assume that each thread can handle the same number of elements + // in the output and weights as in the input. + static_assert(sizeof(input_t) >= sizeof(output_t)); + static_assert(sizeof(input_t) >= sizeof(weight_t)); + // The number of columns fetched per load from input: one per thread. + enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; + // The total number of vectorized loads/stores per hidden vector. + enum { VEC_COLS = COLS / ELTS_PER_LDG }; + // The number of loads per thread for the input. 
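// Illustrative sketch (not part of the diff): the traits above turn a hidden size into a
// vectorized access plan -- how many elements fit in one load, how many vector columns a row has,
// and how many loads each thread issues. The same arithmetic as a constexpr helper with
// hypothetical names:
#include <cstdint>

struct VectorPlan {
  uint32_t elts_per_ldg;     // elements per vectorized load
  uint32_t vec_cols;         // vectorized columns per row
  uint32_t ldgs_per_thread;  // loads each thread issues
};

constexpr VectorPlan make_plan(uint32_t hidden_size, uint32_t bytes_per_elt,
                               uint32_t bytes_per_ldg, uint32_t threads_per_row,
                               uint32_t ctas_per_row) {
  return VectorPlan{bytes_per_ldg / bytes_per_elt,
                    hidden_size / (bytes_per_ldg / bytes_per_elt),
                    hidden_size / (bytes_per_ldg / bytes_per_elt) / (ctas_per_row * threads_per_row)};
}

// Example: hidden 4096, fp16 input, 16-byte loads, 128 threads per row, 1 CTA per row.
constexpr VectorPlan kExample = make_plan(4096, 2, 16, 128, 1);
static_assert(kExample.ldgs_per_thread == 4, "8 elts per load, 512 vector cols, 4 loads per thread");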
+ enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; + static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); + // static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); + + using Stats = transformer_engine::Stats; + enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/transformer_engine/common/nvtx.h b/transformer_engine/common/nvtx.h index 1a23f1b08a4a317379521aeef3ab7e211ed4e8f3..191f3b06faccd535d875f7476acabe798f7acabc 100644 --- a/transformer_engine/common/nvtx.h +++ b/transformer_engine/common/nvtx.h @@ -7,19 +7,16 @@ #ifndef TRANSFORMER_ENGINE_COMMON_NVTX_H_ #define TRANSFORMER_ENGINE_COMMON_NVTX_H_ -#include #include +#include + namespace transformer_engine::nvtx { struct NVTXWrapper { - explicit NVTXWrapper(const std::string &name) { - nvtxRangePush(name.c_str()); - } + explicit NVTXWrapper(const std::string &name) { nvtxRangePush(name.c_str()); } - ~NVTXWrapper() { - nvtxRangePop(); - } + ~NVTXWrapper() { nvtxRangePop(); } }; } // namespace transformer_engine::nvtx diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 2e232f50e2c81c50c4e354ed1240bc5c9566a134..fcace6ac3ddf2bac062d41355239e505ecdb6277 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -7,12 +7,12 @@ #include #include -#include #include +#include #include "../common.h" -#include "../util/logging.h" #include "../util/cuda_runtime.h" +#include "../util/logging.h" namespace transformer_engine { namespace delayed_scaling_recipe { @@ -24,18 +24,19 @@ enum class AmaxComputeAlgo { INVALID, MOST_RECENT, MAX }; const char* dtype_name(DType dtype) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, Type, - return TypeInfo::name; - ); // NOLINT(*) + return TypeInfo::name;); // NOLINT(*) return ""; } // Maximum representable value of an FP8 dtype inline float fp8_dtype_max(DType dtype) { switch (dtype) { - case DType::kFloat8E4M3: return 448; - case DType::kFloat8E5M2: return 57344; - default: - NVTE_ERROR("Expected FP8 dtype, but got ", dtype_name(dtype)); + case DType::kFloat8E4M3: + return 448; + case DType::kFloat8E5M2: + return 57344; + default: + NVTE_ERROR("Expected FP8 dtype, but got ", dtype_name(dtype)); } return 0; } @@ -58,12 +59,12 @@ struct OtherParams { #if CUDART_VERSION >= 12010 constexpr size_t max_constant_memory_per_kernel = 32768; -constexpr size_t AMAX_PARAMS_LIMIT = ( - max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); +constexpr size_t AMAX_PARAMS_LIMIT = + (max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); #else constexpr size_t max_constant_memory_per_kernel = 4096; -constexpr size_t AMAX_PARAMS_LIMIT = ( - max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); +constexpr size_t AMAX_PARAMS_LIMIT = + (max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); #endif struct AmaxParams { @@ -82,17 +83,10 @@ constexpr size_t bsize = 256; * Grid dims: num_scales x 1 x 1 */ __global__ void __launch_bounds__(bsize) -kernel(const float* amax_history_ptr, - const float* scale_ptr, - const float* scale_inv_ptr, - const unsigned char* scale_inv_mask_ptr, - float* updated_amax_history_ptr, - float* updated_scale_ptr, - float* updated_scale_inv_ptr, - size_t amax_history_length, - size_t amax_history_stride, - AmaxComputeAlgo amax_compute_algo, - float scaled_max) { + 
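// Illustrative sketch (not part of the diff): NVTXWrapper above is an RAII guard -- the
// constructor pushes an NVTX range and the destructor pops it, so profiler ranges follow C++
// scopes automatically. Typical usage; the function and range names are hypothetical and the
// include path may differ depending on the NVTX packaging used.
#include <nvToolsExt.h>
#include <string>

struct NVTXWrapper {
  explicit NVTXWrapper(const std::string &name) { nvtxRangePush(name.c_str()); }
  ~NVTXWrapper() { nvtxRangePop(); }
};

void run_layernorm_fwd() {
  NVTXWrapper guard{"layernorm_fwd"};  // range opens here...
  /* ... enqueue kernels ... */
}                                      // ...and pops here, even on early return or exception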
kernel(const float* amax_history_ptr, const float* scale_ptr, const float* scale_inv_ptr, + const unsigned char* scale_inv_mask_ptr, float* updated_amax_history_ptr, + float* updated_scale_ptr, float* updated_scale_inv_ptr, size_t amax_history_length, + size_t amax_history_stride, AmaxComputeAlgo amax_compute_algo, float scaled_max) { const size_t tid = threadIdx.x; const size_t bid = blockIdx.x; @@ -109,22 +103,21 @@ kernel(const float* amax_history_ptr, const size_t i = off + tid; float a = 0; if (i < length) { - a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; + a = (i < length - 1) ? amax_history[(i + 1) * stride] : last_amax; amax = fmaxf(amax, a); } __syncthreads(); // In case roll is in-place if (i < length) { - updated_amax_history[i*stride] = (i > 0) ? a : 0; + updated_amax_history[i * stride] = (i > 0) ? a : 0; } } // Compute amax to use for scaling factor switch (amax_compute_algo) { - case AmaxComputeAlgo::MOST_RECENT: - amax = last_amax; - break; - case AmaxComputeAlgo::MAX: - { + case AmaxComputeAlgo::MOST_RECENT: + amax = last_amax; + break; + case AmaxComputeAlgo::MAX: { __shared__ float shared_amax[bsize]; shared_amax[tid] = amax; __syncthreads(); @@ -136,10 +129,9 @@ kernel(const float* amax_history_ptr, __syncthreads(); } amax = shared_amax[tid]; - } - break; - default: - amax = 0; + } break; + default: + amax = 0; } } @@ -157,7 +149,7 @@ kernel(const float* amax_history_ptr, // amax won't get mapped to the FP8 max representable, but rather // something below that, but this is the best thing we can do. if (isinf(scale)) { - scale = std::numeric_limits::max(); + scale = std::numeric_limits::max(); } updated_scale_ptr[bid] = scale; @@ -179,12 +171,8 @@ kernel(const float* amax_history_ptr, * Grid dims: num_tensors x 1 x 1 */ __global__ void __launch_bounds__(bsize) -kernel_bulk( - float* amax_reduction_buffer, - AmaxParams p, - size_t amax_history_length, - AmaxComputeAlgo amax_compute_algo, - float scaled_max) { + kernel_bulk(float* amax_reduction_buffer, AmaxParams p, size_t amax_history_length, + AmaxComputeAlgo amax_compute_algo, float scaled_max) { const size_t bid = blockIdx.x; const size_t tid = threadIdx.x; const int num_scale = p.param[bid].num_scale; @@ -201,32 +189,32 @@ kernel_bulk( // Roll amax history const auto& length = amax_history_length; const auto& stride = p.param[bid].num_scale; - auto* amax_history = p.param[bid].amax_history+count; - const auto last_amax = ((amax_reduction_buffer != nullptr) - && (amax_reduction_buffer[offset_in_buffer+count] != 0.0f)) ? - amax_reduction_buffer[offset_in_buffer+count] : amax_history[0]; + auto* amax_history = p.param[bid].amax_history + count; + const auto last_amax = ((amax_reduction_buffer != nullptr) && + (amax_reduction_buffer[offset_in_buffer + count] != 0.0f)) + ? amax_reduction_buffer[offset_in_buffer + count] + : amax_history[0]; if (last_amax != 0.0f) { for (size_t off = 0; off < length; off += bsize) { const size_t i = off + tid; float a = 0; if (i < length) { - a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; + a = (i < length - 1) ? amax_history[(i + 1) * stride] : last_amax; amax = fmaxf(amax, a); } __syncthreads(); // Inplace roll if (i < length) { - amax_history[i*stride] = (i > 0) ? a : 0; + amax_history[i * stride] = (i > 0) ? 
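// Illustrative sketch (not part of the diff): the delayed-scaling kernels above reduce the amax
// window (most-recent or max), roll the history by one step, and recompute the scale as
// (fp8_max / 2^margin) / amax with a clamp against overflow. A single-tensor host reference of
// that recipe; names are hypothetical, the history is assumed non-empty, and the roll only
// roughly mirrors the in-kernel roll (slot 0 cleared, slots shifted, newest amax at the end).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

float delayed_scaling_update(std::vector<float> *amax_history, float new_amax,
                             bool use_most_recent, float fp8_max, float margin) {
  auto &h = *amax_history;
  // Roll the history window by one step.
  for (std::size_t i = 1; i + 1 < h.size(); ++i) h[i] = h[i + 1];
  h.back() = new_amax;
  if (h.size() > 1) h[0] = 0.f;
  // Reduce the window: either the most recent amax or the maximum over the rolled history.
  const float amax = use_most_recent ? new_amax : *std::max_element(h.begin(), h.end());
  // scale = (fp8_max / 2^margin) / amax, guarded against amax == 0 and overflow to infinity.
  float scale = 1.f;
  const float scaled_max = fp8_max / std::exp2(margin);
  if (std::isfinite(amax) && amax > 0.f) {
    scale = scaled_max / amax;
  }
  if (std::isinf(scale)) {
    scale = std::numeric_limits<float>::max();
  }
  return scale;
}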
a : 0; } } } // Compute amax to use for scaling factor switch (amax_compute_algo) { - case AmaxComputeAlgo::MOST_RECENT: - amax = last_amax; - break; - case AmaxComputeAlgo::MAX: - { + case AmaxComputeAlgo::MOST_RECENT: + amax = last_amax; + break; + case AmaxComputeAlgo::MAX: { __shared__ float shared_amax[bsize]; shared_amax[tid] = amax; __syncthreads(); @@ -238,10 +226,9 @@ kernel_bulk( __syncthreads(); } amax = shared_amax[tid]; - } - break; - default: - amax = 0; + } break; + default: + amax = 0; } } @@ -269,7 +256,7 @@ kernel_bulk( // amax won't get mapped to the FP8 max representable, but rather // something below that, but this is the best thing we can do. if (isinf(scale)) { - scale = std::numeric_limits<float>::max(); + scale = std::numeric_limits<float>::max(); } p.param[bid].scale[count] = scale; p.param[bid].scale_inv[count] = 1 / scale; @@ -281,24 +268,17 @@ } // namespace - -void amax_and_scale_update(const Tensor &amax_history, - const Tensor &scale, - const Tensor &scale_inv, - const Tensor &scale_inv_mask, - Tensor *updated_amax_history_, - Tensor *updated_scale_, - Tensor *updated_scale_inv_, - const std::string &amax_compute_algo, - DType fp8_dtype, - float margin, +void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, const Tensor& scale_inv, + const Tensor& scale_inv_mask, Tensor* updated_amax_history_, + Tensor* updated_scale_, Tensor* updated_scale_inv_, + const std::string& amax_compute_algo, DType fp8_dtype, float margin, cudaStream_t stream) { auto& updated_amax_history = *updated_amax_history_; auto& updated_scale = *updated_scale_; auto& updated_scale_inv = *updated_scale_inv_; // Number of elements in tensor - auto numel = [] (const Tensor &tensor) -> size_t { + auto numel = [](const Tensor& tensor) -> size_t { size_t acc = 1; for (const auto& dim : tensor.data.shape) { acc *= dim; @@ -307,48 +287,40 @@ void amax_and_scale_update(const Tensor &amax_history, }; // Check tensors - NVTE_CHECK(amax_history.data.shape.size() == 2, - "Found ", amax_history.data.shape.size(), " dims"); + NVTE_CHECK(amax_history.data.shape.size() == 2, "Found ", amax_history.data.shape.size(), + " dims"); const size_t amax_history_length = amax_history.data.shape[0]; const size_t num_scales = amax_history.data.shape[1]; - NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, - "Found ", dtype_name(amax_history.data.dtype), "."); - NVTE_CHECK(numel(scale) == num_scales, - "Expected ", num_scales, " elements, ", - "but found ", numel(scale), "."); - NVTE_CHECK(scale.data.dtype == DType::kFloat32, - "Found ", dtype_name(scale.data.dtype), "."); + NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, "Found ", + dtype_name(amax_history.data.dtype), "."); + NVTE_CHECK(numel(scale) == num_scales, "Expected ", num_scales, " elements, ", "but found ", + numel(scale), "."); + NVTE_CHECK(scale.data.dtype == DType::kFloat32, "Found ", dtype_name(scale.data.dtype), "."); if (scale_inv_mask.data.dptr != nullptr) { - NVTE_CHECK(numel(scale_inv) == num_scales, - "Expected ", num_scales, " elements, ", - "but found ", numel(scale_inv), "."); + NVTE_CHECK(numel(scale_inv) == num_scales, "Expected ", num_scales, " elements, ", "but found ", + numel(scale_inv), "."); NVTE_CHECK(scale_inv.data.dtype == DType::kFloat32); - NVTE_CHECK(numel(scale_inv_mask) == num_scales, - "Expected ", num_scales, " elements, ", + NVTE_CHECK(numel(scale_inv_mask) == num_scales, "Expected ", num_scales, " elements, ", "but found ", numel(scale_inv_mask), "."); - NVTE_CHECK(scale_inv_mask.data.dtype == 
DType::kByte, - "Found ", dtype_name(scale_inv_mask.data.dtype), "."); + NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, "Found ", + dtype_name(scale_inv_mask.data.dtype), "."); } - NVTE_CHECK(updated_amax_history.data.shape.size() == 2, - "Found ", updated_amax_history.data.shape.size(), " dims."); - NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, - "Expected ", amax_history_length, ", ", - "but found ", updated_amax_history.data.shape[0]); - NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales, - "Expected ", num_scales, ", ", + NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ", + updated_amax_history.data.shape.size(), " dims."); + NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, "Expected ", + amax_history_length, ", ", "but found ", updated_amax_history.data.shape[0]); + NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales, "Expected ", num_scales, ", ", "but found ", updated_amax_history.data.shape[1]); - NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, - "Got ", dtype_name(updated_amax_history.data.dtype), "."); - NVTE_CHECK(numel(updated_scale) == num_scales, - "Expected ", num_scales, " elements, ", + NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, "Got ", + dtype_name(updated_amax_history.data.dtype), "."); + NVTE_CHECK(numel(updated_scale) == num_scales, "Expected ", num_scales, " elements, ", "but found ", numel(updated_scale), "."); - NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, - "Got ", dtype_name(updated_scale.data.dtype), "."); - NVTE_CHECK(numel(updated_scale_inv) == num_scales, - "Expected ", num_scales, " elements, ", + NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, "Got ", + dtype_name(updated_scale.data.dtype), "."); + NVTE_CHECK(numel(updated_scale_inv) == num_scales, "Expected ", num_scales, " elements, ", "but found ", numel(updated_scale_inv), "."); - NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32, - "Got ", dtype_name(updated_scale_inv.data.dtype), "."); + NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32, "Got ", + dtype_name(updated_scale_inv.data.dtype), "."); // amax value to use for updating scaling factor AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; @@ -366,31 +338,23 @@ void amax_and_scale_update(const Tensor &amax_history, // Launch CUDA kernel constexpr size_t block_size = amax_and_scale_update_impl::bsize; const size_t grid_size = num_scales; - amax_and_scale_update_impl::kernel - <<>>( - static_cast(amax_history.data.dptr), - static_cast(scale.data.dptr), + amax_and_scale_update_impl::kernel<<>>( + static_cast(amax_history.data.dptr), static_cast(scale.data.dptr), static_cast(scale_inv.data.dptr), static_cast(scale_inv_mask.data.dptr), static_cast(updated_amax_history.data.dptr), static_cast(updated_scale.data.dptr), - static_cast(updated_scale_inv.data.dptr), - amax_history_length, - num_scales, - amax_compute_algo_, - scaled_max); + static_cast(updated_scale_inv.data.dptr), amax_history_length, num_scales, + amax_compute_algo_, scaled_max); NVTE_CHECK_CUDA(cudaGetLastError()); } - -void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, +void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, std::vector amax_histories, std::vector scales, std::vector scale_invs, - const std::string &amax_compute_algo, - DType fp8_dtype, - float margin, - cudaStream_t stream) { + const std::string& amax_compute_algo, DType fp8_dtype, + float margin, 
cudaStream_t stream) { using namespace transformer_engine; // amax value to use for updating scaling factor @@ -407,7 +371,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin); // Number of elements in tensor - auto numel = [] (const Tensor *tensor) -> size_t { + auto numel = [](const Tensor* tensor) -> size_t { size_t acc = 1; for (const auto& dim : tensor->data.shape) { acc *= dim; @@ -418,7 +382,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, // Number of tensors in the bulk const size_t num_tensors = amax_histories.size(); size_t num_remaining_tensors = num_tensors; - const int num_kernels = (num_tensors+AMAX_PARAMS_LIMIT-1)/AMAX_PARAMS_LIMIT; + const int num_kernels = (num_tensors + AMAX_PARAMS_LIMIT - 1) / AMAX_PARAMS_LIMIT; size_t amax_history_length = 0; if (num_tensors > 0) { amax_history_length = amax_histories[0]->data.shape[0]; @@ -429,27 +393,26 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, AmaxParams p; for (int iter = 0; iter < num_kernels; iter++) { size_t kernel_num_scales = 0; - size_t kernel_num_tensors = (iter == (num_kernels - 1)) - ? num_remaining_tensors: AMAX_PARAMS_LIMIT; + size_t kernel_num_tensors = + (iter == (num_kernels - 1)) ? num_remaining_tensors : AMAX_PARAMS_LIMIT; for (size_t pi = 0; pi < kernel_num_tensors; pi++) { size_t i = iter * AMAX_PARAMS_LIMIT + pi; // Check tensors int num_scale = amax_histories[i]->data.shape[1]; - NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, - "Found ", dtype_name(amax_histories[i]->data.dtype), "."); - NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, - "Found ", amax_histories[i]->data.shape.size(), " dims"); - NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, - "Expected ", amax_history_length * num_scale, " elements, ", - "but found ", numel(amax_histories[i]), "."); - NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, - "Found ", dtype_name(scales[i]->data.dtype), "."); - NVTE_CHECK(scales[i]->data.shape.size() == 1, - "Found ", scales[i]->data.shape.size(), " dims"); - NVTE_CHECK(numel(scales[i]) == num_scale, - "Expected ", num_scale, " elements, ", - "Found ", numel(scales[i]), "."); + NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, "Found ", + dtype_name(amax_histories[i]->data.dtype), "."); + NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, "Found ", + amax_histories[i]->data.shape.size(), " dims"); + NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, "Expected ", + amax_history_length * num_scale, " elements, ", "but found ", + numel(amax_histories[i]), "."); + NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, "Found ", + dtype_name(scales[i]->data.dtype), "."); + NVTE_CHECK(scales[i]->data.shape.size() == 1, "Found ", scales[i]->data.shape.size(), + " dims"); + NVTE_CHECK(numel(scales[i]) == num_scale, "Expected ", num_scale, " elements, ", "Found ", + numel(scales[i]), "."); // amax parameters kernel_num_scales += num_scale; @@ -462,13 +425,8 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, // Launch CUDA kernel size_t grid_size = kernel_num_tensors; const size_t block_size = amax_and_scale_update_impl::bsize; - amax_and_scale_update_impl::kernel_bulk - <<>>( - amax_buffer, - p, - amax_history_length, - amax_compute_algo_, - scaled_max); + amax_and_scale_update_impl::kernel_bulk<<>>( + amax_buffer, p, 
amax_history_length, amax_compute_algo_, scaled_max); NVTE_CHECK_CUDA(cudaGetLastError()); // shift amax buffer pointer @@ -482,44 +440,25 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, } // namespace delayed_scaling_recipe } // namespace transformer_engine - -void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history, - const NVTETensor scale, - const NVTETensor scale_inv, - const NVTETensor scale_inv_mask, - NVTETensor updated_amax_history, - NVTETensor updated_scale, - NVTETensor updated_scale_inv, - const char *amax_compute_algo, - NVTEDType fp8_dtype, - float margin, - cudaStream_t stream) { +void nvte_delayed_scaling_recipe_amax_and_scale_update( + const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv, + const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale, + NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, + cudaStream_t stream) { NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update); using namespace transformer_engine; delayed_scaling_recipe::amax_and_scale_update( - *reinterpret_cast(amax_history), - *reinterpret_cast(scale), - *reinterpret_cast(scale_inv), - *reinterpret_cast(scale_inv_mask), - reinterpret_cast(updated_amax_history), - reinterpret_cast(updated_scale), - reinterpret_cast(updated_scale_inv), - amax_compute_algo, - static_cast(fp8_dtype), - margin, - stream); + *reinterpret_cast(amax_history), *reinterpret_cast(scale), + *reinterpret_cast(scale_inv), *reinterpret_cast(scale_inv_mask), + reinterpret_cast(updated_amax_history), reinterpret_cast(updated_scale), + reinterpret_cast(updated_scale_inv), amax_compute_algo, + static_cast(fp8_dtype), margin, stream); } - void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( - const NVTETensor amax_reduction_buffer, - std::vector amax_histories, - std::vector scales, - std::vector scale_invs, - const char *amax_compute_algo, - NVTEDType fp8_dtype, - float margin, - cudaStream_t stream) { + const NVTETensor amax_reduction_buffer, std::vector amax_histories, + std::vector scales, std::vector scale_invs, + const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream) { NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction); using namespace transformer_engine; size_t num_tensors = amax_histories.size(); @@ -530,12 +469,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( t_scale_invs.push_back(reinterpret_cast(scale_invs[i])); } delayed_scaling_recipe::amax_and_scale_update_after_reduction( - *reinterpret_cast(amax_reduction_buffer), - t_amax_histories, - t_scales, - t_scale_invs, - amax_compute_algo, - static_cast(fp8_dtype), - margin, - stream); + *reinterpret_cast(amax_reduction_buffer), t_amax_histories, t_scales, + t_scale_invs, amax_compute_algo, static_cast(fp8_dtype), margin, stream); } diff --git a/transformer_engine/common/rmsnorm/rmsnorm.h b/transformer_engine/common/rmsnorm/rmsnorm.h index 3ce606dd54adf188c17de307e2dc602921ac5dc6..8b4e1cf24e70c81a07169f663e94d2e1a87ee768 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm.h +++ b/transformer_engine/common/rmsnorm/rmsnorm.h @@ -8,6 +8,7 @@ #define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ #include + #include #include #include @@ -15,7 +16,6 @@ #include #include "../common.h" - #include "../layer_norm/ln.h" namespace transformer_engine { @@ -47,40 +47,40 @@ extern BwdGeneralRegistry 
BWD_GENERAL_FUNCS; template struct FwdTunedRegistrar { - explicit FwdTunedRegistrar(FwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(HIDDEN_SIZE); - FWD_TUNED_FUNCS.insert({key, f}); - } + explicit FwdTunedRegistrar(FwdFunction f) { + uint64_t key = layer_norm::Types2Key::get(HIDDEN_SIZE); + FWD_TUNED_FUNCS.insert({key, f}); + } }; ////////////////////////////////////////////////////////////////////////////////////////////////// template struct FwdGeneralRegistrar { - explicit FwdGeneralRegistrar(FwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(0); - FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } + explicit FwdGeneralRegistrar(FwdFunction f) { + uint64_t key = layer_norm::Types2Key::get(0); + FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); + } }; ////////////////////////////////////////////////////////////////////////////////////////////////// template struct BwdTunedRegistrar { - explicit BwdTunedRegistrar(BwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(HIDDEN_SIZE); - BWD_TUNED_FUNCS.insert({key, f}); - } + explicit BwdTunedRegistrar(BwdFunction f) { + uint64_t key = layer_norm::Types2Key::get(HIDDEN_SIZE); + BWD_TUNED_FUNCS.insert({key, f}); + } }; ////////////////////////////////////////////////////////////////////////////////////////////////// template struct BwdGeneralRegistrar { - explicit BwdGeneralRegistrar(BwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(0); - BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } + explicit BwdGeneralRegistrar(BwdFunction f) { + uint64_t key = layer_norm::Types2Key::get(0); + BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); + } }; } // namespace rmsnorm diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp index 840d84e599935ff1092ab9e2de3dd6797898a217..31e2b4f71e6c8c256d84580fcd411e39c0345713 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp @@ -4,14 +4,13 @@ * See LICENSE for license information. 
************************************************************************/ -#include "transformer_engine/rmsnorm.h" - #include #include #include -#include "rmsnorm.h" #include "../common.h" +#include "rmsnorm.h" +#include "transformer_engine/rmsnorm.h" /* @@ -49,85 +48,70 @@ BwdTunedRegistry BWD_TUNED_FUNCS; FwdGeneralRegistry FWD_GENERAL_FUNCS; BwdGeneralRegistry BWD_GENERAL_FUNCS; -FwdFunction &get_fwd_launcher(DType wtype, - DType itype, - DType otype, - DType ctype, +FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype, const layer_norm::FwdParams ¶ms) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void *ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 - && is_aligned(params.x) - && is_aligned(params.rs) - && is_aligned(params.gamma) - && is_aligned(params.z) - && FWD_TUNED_FUNCS.count(tuned_key) > 0) { - return FWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (FWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("FWD: Unsupported types."); - } - auto &general_func_map = FWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } + // Look for tuned kernel + auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); + auto is_aligned = [](const void *ptr) -> bool { + // Assume vectorized memory accesses are <=16B + return reinterpret_cast(ptr) % 16 == 0; + }; + if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && + is_aligned(params.gamma) && is_aligned(params.z) && FWD_TUNED_FUNCS.count(tuned_key) > 0) { + return FWD_TUNED_FUNCS.at(tuned_key); + } + + // Pick general kernel + auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); + if (FWD_GENERAL_FUNCS.count(general_key) == 0) { + NVTE_ERROR("FWD: Unsupported types."); + } + auto &general_func_map = FWD_GENERAL_FUNCS.at(general_key); + auto func_iter = general_func_map.lower_bound(params.cols); + if (func_iter == general_func_map.end()) { + // Hidden size is too big, need to use multi-CTA + return general_func_map.rbegin()->second; + } else { + return func_iter->second; + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -BwdFunction &get_bwd_launcher(DType wtype, - DType itype, - DType otype, - DType ctype, +BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype, const layer_norm::BwdParams ¶ms) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void *ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 - && is_aligned(params.x) - && is_aligned(params.rs) - && is_aligned(params.gamma) - && is_aligned(params.dz) - && is_aligned(params.dx) - && is_aligned(params.dgamma) - && is_aligned(params.dgamma_part) - && layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) { - return BWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if 
(BWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("BWD: Unsupported types."); - } - auto &general_func_map = BWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } + // Look for tuned kernel + auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); + auto is_aligned = [](const void *ptr) -> bool { + // Assume vectorized memory accesses are <=16B + return reinterpret_cast(ptr) % 16 == 0; + }; + if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && + is_aligned(params.gamma) && is_aligned(params.dz) && is_aligned(params.dx) && + is_aligned(params.dgamma) && is_aligned(params.dgamma_part) && + layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) { + return BWD_TUNED_FUNCS.at(tuned_key); + } + + // Pick general kernel + auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); + if (BWD_GENERAL_FUNCS.count(general_key) == 0) { + NVTE_ERROR("BWD: Unsupported types."); + } + auto &general_func_map = BWD_GENERAL_FUNCS.at(general_key); + auto func_iter = general_func_map.lower_bound(params.cols); + if (func_iter == general_func_map.end()) { + // Hidden size is too big, need to use multi-CTA + return general_func_map.rbegin()->second; + } else { + return func_iter->second; + } } // //////////////////////////////////////////////////////////////////////////////////////////////////// inline size_t product(const std::vector &shape) { - return std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>()); + return std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>()); } } // namespace rmsnorm @@ -137,213 +121,211 @@ inline size_t product(const std::vector &shape) { void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z, Tensor *rsigma, cudaStream_t stream, const int multiprocessorCount, Tensor *workspace, Tensor *barrier, const bool zero_centered_gamma) { - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - auto otype = z->data.dtype; - const bool fp8_out = is_fp8_dtype(otype); - auto ctype = DType::kFloat32; - - NVTE_CHECK(x.data.shape.size() == 2); - - const size_t rows = x.data.shape[0]; - const size_t cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(hidden_size == cols); - NVTE_CHECK(epsilon >= 0.f); - - NVTE_CHECK(z->data.shape == x.data.shape); - - NVTE_CHECK(rsigma->data.shape == std::vector{rows}); - NVTE_CHECK(rsigma->data.dtype == ctype); - - rmsnorm::LaunchParams launch_params; - - launch_params.multiprocessorCount = multiprocessorCount; - launch_params.stream = stream; - - // Set the kernel runtime parameters. - rmsnorm::FwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = nullptr; - params.rs = rsigma->data.dptr; - params.gamma = gamma.data.dptr; - params.beta = nullptr; - params.z = z->data.dptr; - params.epsilon = epsilon; - params.amax = z->amax.dptr; - params.scale = z->scale.dptr; - params.fp8_out = fp8_out; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. 
- launcher(launch_params, true); - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } - - if (workspace->data.dptr == nullptr) { - NVTE_CHECK(barrier->data.dptr == nullptr); - - workspace->data.dtype = DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(workspace->data.dtype == DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); - } - - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(x, "x"); - CheckInputTensor(gamma, "gamma"); - - CheckOutputTensor(*z, "z"); - CheckOutputTensor(*rsigma, "rsigma"); - - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - } - - // Clear buffers - if (params.fp8_out) { - cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype), - stream); - } - if (launch_params.barrier_size > 0) { - cudaMemsetAsync(params.barrier, 0, - rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. - launcher(launch_params, false); + auto itype = x.data.dtype; + auto wtype = gamma.data.dtype; + auto otype = z->data.dtype; + const bool fp8_out = is_fp8_dtype(otype); + auto ctype = DType::kFloat32; + + NVTE_CHECK(x.data.shape.size() == 2); + + const size_t rows = x.data.shape[0]; + const size_t cols = x.data.shape[1]; + const auto hidden_size = gamma.data.shape[0]; + + NVTE_CHECK(hidden_size == cols); + NVTE_CHECK(epsilon >= 0.f); + + NVTE_CHECK(z->data.shape == x.data.shape); + + NVTE_CHECK(rsigma->data.shape == std::vector{rows}); + NVTE_CHECK(rsigma->data.dtype == ctype); + + rmsnorm::LaunchParams launch_params; + + launch_params.multiprocessorCount = multiprocessorCount; + launch_params.stream = stream; + + // Set the kernel runtime parameters. + rmsnorm::FwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x = x.data.dptr; + params.mu = nullptr; + params.rs = rsigma->data.dptr; + params.gamma = gamma.data.dptr; + params.beta = nullptr; + params.z = z->data.dptr; + params.epsilon = epsilon; + params.amax = z->amax.dptr; + params.scale = z->scale.dptr; + params.fp8_out = fp8_out; + params.zero_centered_gamma = zero_centered_gamma; + + // Request the kernel launcher. + auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, params); + + // Query the kernel-specific launch parameters. 
+ launcher(launch_params, true); + if (launch_params.workspace_bytes == 0) { + launch_params.workspace_bytes = 1; + } + + if (workspace->data.dptr == nullptr) { + NVTE_CHECK(barrier->data.dptr == nullptr); + + workspace->data.dtype = DType::kByte; + workspace->data.shape = {launch_params.workspace_bytes}; + + barrier->data.dtype = DType::kInt32; + barrier->data.shape = {launch_params.barrier_size}; return; + } else { + NVTE_CHECK(workspace->data.dtype == DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); + } + + // Tensor checks are delayed here in order to recover workspace sizes with null data + CheckInputTensor(x, "x"); + CheckInputTensor(gamma, "gamma"); + + CheckOutputTensor(*z, "z"); + CheckOutputTensor(*rsigma, "rsigma"); + + if (launch_params.barrier_size > 0) { + params.workspace = workspace->data.dptr; + params.barrier = reinterpret_cast(barrier->data.dptr); + } + + // Clear buffers + if (params.fp8_out) { + cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype), + stream); + } + if (launch_params.barrier_size > 0) { + cudaMemsetAsync(params.barrier, 0, + rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), + stream); + } + + // Launch the kernel. + launcher(launch_params, false); + + return; } void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma, Tensor *dx, Tensor *dgamma, Tensor *dgamma_part, cudaStream_t stream, const int multiprocessorCount, Tensor *workspace, Tensor *barrier, const bool zero_centered_gamma) { - using namespace transformer_engine; - - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - auto otype = wtype; - auto ctype = DType::kFloat32; - - NVTE_CHECK(dz.data.dtype == otype); - NVTE_CHECK(rsigma.data.dtype == ctype); - - NVTE_CHECK(x.data.shape.size() == 2); - NVTE_CHECK(dz.data.shape == x.data.shape); - - const auto rows = x.data.shape[0]; - const auto cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(gamma.data.shape[0] == cols); - - NVTE_CHECK(dx->data.shape == x.data.shape); - NVTE_CHECK(dx->data.dtype == x.data.dtype); - - NVTE_CHECK(dgamma->data.shape == gamma.data.shape); - NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); - - rmsnorm::LaunchParams launch_params; - launch_params.stream = stream; - launch_params.multiprocessorCount = multiprocessorCount; - - // Set the kernel runtime parameters. - rmsnorm::BwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = nullptr; - params.rs = rsigma.data.dptr; - params.gamma = gamma.data.dptr; - params.dz = dz.data.dptr; - params.dx = dx->data.dptr; - params.dbeta = nullptr; - params.dgamma = dgamma->data.dptr; - params.dbeta_part = nullptr; - params.dgamma_part = dgamma_part->data.dptr; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. 
- launcher(launch_params, true); - - // Populate shape and dtypes for FW to allocate memory - if (dgamma_part->data.dptr == nullptr) { - dgamma_part->data.dtype = ctype; - dgamma_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - workspace->data.dtype = DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - auto pdw_shape = std::vector{ - static_cast(launch_params.params.ctas_per_col), hidden_size}; - NVTE_CHECK(dgamma_part->data.dtype == ctype); - NVTE_CHECK(dgamma_part->data.shape == pdw_shape); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); - } - - if (launch_params.workspace_bytes > 0) { - NVTE_CHECK(workspace->data.dptr != nullptr); - NVTE_CHECK(workspace->data.dtype == DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(dz, "dz"); - CheckInputTensor(x, "x"); - CheckInputTensor(rsigma, "rsigma"); - CheckInputTensor(gamma, "gamma"); - CheckOutputTensor(*dx, "dx"); - CheckOutputTensor(*dgamma, "dgamma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - cudaMemsetAsync(params.barrier, 0, - rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. - launcher(launch_params, false); + using namespace transformer_engine; + + auto itype = x.data.dtype; + auto wtype = gamma.data.dtype; + auto otype = wtype; + auto ctype = DType::kFloat32; + + NVTE_CHECK(dz.data.dtype == otype); + NVTE_CHECK(rsigma.data.dtype == ctype); + + NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(dz.data.shape == x.data.shape); + + const auto rows = x.data.shape[0]; + const auto cols = x.data.shape[1]; + const auto hidden_size = gamma.data.shape[0]; + + NVTE_CHECK(gamma.data.shape[0] == cols); + + NVTE_CHECK(dx->data.shape == x.data.shape); + NVTE_CHECK(dx->data.dtype == x.data.dtype); + + NVTE_CHECK(dgamma->data.shape == gamma.data.shape); + NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); + + rmsnorm::LaunchParams launch_params; + launch_params.stream = stream; + launch_params.multiprocessorCount = multiprocessorCount; + + // Set the kernel runtime parameters. + rmsnorm::BwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x = x.data.dptr; + params.mu = nullptr; + params.rs = rsigma.data.dptr; + params.gamma = gamma.data.dptr; + params.dz = dz.data.dptr; + params.dx = dx->data.dptr; + params.dbeta = nullptr; + params.dgamma = dgamma->data.dptr; + params.dbeta_part = nullptr; + params.dgamma_part = dgamma_part->data.dptr; + params.zero_centered_gamma = zero_centered_gamma; + + // Request the kernel launcher. + auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, params); + + // Query the kernel-specific launch parameters. 
+ launcher(launch_params, true); + + // Populate shape and dtypes for FW to allocate memory + if (dgamma_part->data.dptr == nullptr) { + dgamma_part->data.dtype = ctype; + dgamma_part->data.shape = {static_cast(launch_params.params.ctas_per_col), + hidden_size}; + + workspace->data.dtype = DType::kByte; + workspace->data.shape = {launch_params.workspace_bytes}; + + barrier->data.dtype = DType::kInt32; + barrier->data.shape = {launch_params.barrier_size}; + + return; + } else { + auto pdw_shape = + std::vector{static_cast(launch_params.params.ctas_per_col), hidden_size}; + NVTE_CHECK(dgamma_part->data.dtype == ctype); + NVTE_CHECK(dgamma_part->data.shape == pdw_shape); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); + } + + if (launch_params.workspace_bytes > 0) { + NVTE_CHECK(workspace->data.dptr != nullptr); + NVTE_CHECK(workspace->data.dtype == DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); + } + + // Tensor checks are delayed here in order to recover workspace sizes with null data + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + + if (launch_params.barrier_size > 0) { + params.workspace = workspace->data.dptr; + params.barrier = reinterpret_cast(barrier->data.dptr); + cudaMemsetAsync(params.barrier, 0, + rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), + stream); + } + + // Launch the kernel. + launcher(launch_params, false); } } // namespace transformer_engine @@ -364,18 +346,15 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size const NVTETensor x, // Nxhidden_size const NVTETensor rsigma, // N, FP32! const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, - NVTETensor dgamma_part, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier) { + NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream, + const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { NVTE_API_CALL(nvte_rmsnorm_bwd); using namespace transformer_engine; rmsnorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), *reinterpret_cast(rsigma), *reinterpret_cast(gamma), reinterpret_cast(dx), reinterpret_cast(dgamma), reinterpret_cast(dgamma_part), stream, multiprocessorCount, - reinterpret_cast(workspace), reinterpret_cast(barrier), - false); + reinterpret_cast(workspace), reinterpret_cast(barrier), false); } void nvte_rmsnorm1p_fwd(const NVTETensor x, // Nxhidden_size @@ -394,9 +373,8 @@ void nvte_rmsnorm1p_bwd(const NVTETensor dz, // Nxhidden_size const NVTETensor x, // Nxhidden_size const NVTETensor rsigma, // N, FP32! 
const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, - NVTETensor dgamma_part, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, + NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, + cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { NVTE_API_CALL(nvte_rmsnorm1p_bwd); using namespace transformer_engine; @@ -404,6 +382,5 @@ void nvte_rmsnorm1p_bwd(const NVTETensor dz, // Nxhidden_size *reinterpret_cast(rsigma), *reinterpret_cast(gamma), reinterpret_cast(dx), reinterpret_cast(dgamma), reinterpret_cast(dgamma_part), stream, multiprocessorCount, - reinterpret_cast(workspace), reinterpret_cast(barrier), - true); + reinterpret_cast(workspace), reinterpret_cast(barrier), true); } diff --git a/transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh b/transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh index 879de64586c62ecb68106966d72adad6e9fc2091..92fd850baa0426d1d3bb074e6441c399e7a73c69 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh +++ b/transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh @@ -16,466 +16,462 @@ using namespace transformer_engine; template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel( BwdParams params) { - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { COLS = Ktraits::COLS }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using compute_t = typename Ktraits::compute_t; - using index_t = typename Ktraits::index_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - using Reducer = typename Ktraits::Reducer; - using reduce_t = typename Reducer::Type; - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / Ktraits::WARPS_N; - const index_t warp_n = warp % Ktraits::WARPS_N; - const index_t tid_r = warp_n * THREADS_PER_WARP + lane; - - const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); - - Cvec dzy_sum[LDGS]; - - memset(dzy_sum, 0, sizeof(dzy_sum)); - - compute_t *smem_wgrad = reinterpret_cast(smem_); - char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; - - Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); - - Sum sum; - - constexpr float rn = 1.f / static_cast(COLS); - Wvec gamma[LDGS]; - index_t idx = c; + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { COLS = Ktraits::COLS }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + 
using compute_t = typename Ktraits::compute_t; + using index_t = typename Ktraits::index_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using Reducer = typename Ktraits::Reducer; + using reduce_t = typename Reducer::Type; + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / Ktraits::WARPS_N; + const index_t warp_n = warp % Ktraits::WARPS_N; + const index_t tid_r = warp_n * THREADS_PER_WARP + lane; + + const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); + + Cvec dzy_sum[LDGS]; + + memset(dzy_sum, 0, sizeof(dzy_sum)); + + compute_t *smem_wgrad = reinterpret_cast(smem_); + char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; + + Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); + + Sum sum; + + constexpr float rn = 1.f / static_cast(COLS); + Wvec gamma[LDGS]; + index_t idx = c; #pragma unroll - for (int it = 0; it < LDGS; it++) { - gamma[it].load_from(params.gamma, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } + for (int it = 0; it < LDGS; it++) { + gamma[it].load_from(params.gamma, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the // last blocks with syncthreads! // grid stride over rows #pragma unroll 1 - for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { - const compute_t rs_r = static_cast(params.rs)[row]; - Ivec x[LDGS]; - Ovec dz[LDGS]; - index_t idx = row * Ktraits::VEC_COLS + c; + for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { + const compute_t rs_r = static_cast(params.rs)[row]; + Ivec x[LDGS]; + Ovec dz[LDGS]; + index_t idx = row * Ktraits::VEC_COLS + c; #pragma unroll - for (int it = 0; it < LDGS; it++) { - dz[it].load_from(params.dz, idx); - x[it].load_from(params.x, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } + for (int it = 0; it < LDGS; it++) { + dz[it].load_from(params.dz, idx); + x[it].load_from(params.x, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } - compute_t dy[LDGS * NUM_ELTS]; - compute_t y[LDGS * NUM_ELTS]; + compute_t dy[LDGS * NUM_ELTS]; + compute_t y[LDGS * NUM_ELTS]; - compute_t mdyy_local = 0.f; + compute_t mdyy_local = 0.f; #pragma unroll - for (int it = 0; it < LDGS; it++) { + for (int it = 0; it < LDGS; it++) { #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t x_tmp = x[it].data.elt[jt]; - compute_t y_tmp = rs_r * (x_tmp); - const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f; - compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift; - dy_tmp *= compute_t(dz[it].data.elt[jt]); - compute_t dz_tmp = dz[it].data.elt[jt]; + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t x_tmp = x[it].data.elt[jt]; + compute_t y_tmp = rs_r * (x_tmp); + const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 
1.0f : 0.f; + compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift; + dy_tmp *= compute_t(dz[it].data.elt[jt]); + compute_t dz_tmp = dz[it].data.elt[jt]; - mdyy_local += dy_tmp * y_tmp; + mdyy_local += dy_tmp * y_tmp; - dy[it * NUM_ELTS + jt] = dy_tmp; - y[it * NUM_ELTS + jt] = y_tmp; + dy[it * NUM_ELTS + jt] = dy_tmp; + y[it * NUM_ELTS + jt] = y_tmp; - dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; - } - } + dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; + } + } - reduce_t result = reducer.allreduce({0, mdyy_local}, sum); - mdyy_local = Get<1>::of(result) * rn; + reduce_t result = reducer.allreduce({0, mdyy_local}, sum); + mdyy_local = Get<1>::of(result) * rn; - Ivec dx[LDGS]; - idx = row * Ktraits::VEC_COLS + c; + Ivec dx[LDGS]; + idx = row * Ktraits::VEC_COLS + c; #pragma unroll - for (int it = 0; it < LDGS; it++) { + for (int it = 0; it < LDGS; it++) { #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t dy_tmp = dy[it * NUM_ELTS + jt]; - compute_t y_tmp = y[it * NUM_ELTS + jt]; - compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp)); - dx[it].data.elt[jt] = dx_tmp; - } - dx[it].store_to(params.dx, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - } // end: grid stride loop + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t dy_tmp = dy[it * NUM_ELTS + jt]; + compute_t y_tmp = y[it * NUM_ELTS + jt]; + compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp)); + dx[it].data.elt[jt] = dx_tmp; + } + dx[it].store_to(params.dx, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + } // end: grid stride loop - if (WARPS_M == 1) { - idx = r * Ktraits::VEC_COLS + c; + if (WARPS_M == 1) { + idx = r * Ktraits::VEC_COLS + c; #pragma unroll - for (int it = 0; it < LDGS; it++) { - dzy_sum[it].store_to(params.dgamma_part, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - } else { - static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, - "Multiple rows per CTA not supported for Multi-CTA."); - // Finalize reduction of part dgamma and dbeta for this CTA - // by reducing over the rows held across the WARPS_M warps + for (int it = 0; it < LDGS; it++) { + dzy_sum[it].store_to(params.dgamma_part, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + } else { + static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, + "Multiple rows per CTA not supported for Multi-CTA."); + // Finalize reduction of part dgamma and dbeta for this CTA + // by reducing over the rows held across the WARPS_M warps - // Assumption: blockSize divides hidden size. - enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; - static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); + // Assumption: blockSize divides hidden size. 
+ enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; + static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); - idx = warp_m * Ktraits::VEC_COLS + tid_r; + idx = warp_m * Ktraits::VEC_COLS + tid_r; #pragma unroll - for (int it = 0; it < LDGS; it++) { - dzy_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dzy_sum[NUM_RES]; - memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); - for (int it = 0; it < ROWS_PER_CTA; it++) { - for (int jt = 0; jt < NUM_RES; jt++) { - cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } + for (int it = 0; it < LDGS; it++) { + dzy_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + compute_t cta_dzy_sum[NUM_RES]; + memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); + for (int it = 0; it < ROWS_PER_CTA; it++) { + for (int jt = 0; jt < NUM_RES; jt++) { + cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } - compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * COLS + tidx; - for (int jt = 0; jt < NUM_RES; jt++) { - *dgamma_part = cta_dzy_sum[jt]; - dgamma_part += Ktraits::THREADS_PER_CTA; - } + compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * COLS + tidx; + for (int jt = 0; jt < NUM_RES; jt++) { + *dgamma_part = cta_dzy_sum[jt]; + dgamma_part += Ktraits::THREADS_PER_CTA; } + } } template __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_finalize_tuned_kernel( BwdParams params) { - using compute_t = typename Kernel_traits::compute_t; - using weight_t = typename Kernel_traits::weight_t; - using index_t = typename Kernel_traits::index_t; - using Reducer = typename Kernel_traits::Reducer; - using reduce_t = typename Reducer::Type; - - Sum sum; - enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; - - __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; - - constexpr uint32_t bidm = 0; - - const uint32_t bidn = blockIdx.x; - const uint32_t tidx = threadIdx.x; - const uint32_t warp = tidx / THREADS_PER_WARP; - const uint32_t lane = tidx % THREADS_PER_WARP; - - Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); - - const uint32_t c = bidn * THREADS_PER_WARP + lane; - const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; - constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; - for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; - col += COL_STRIDE, col_out += COL_STRIDE / 2) { - // Each thread sums over NUM_ELT columns. 
- Vec dbeta_local, dgamma_local; - memset(&dgamma_local, 0, sizeof(dgamma_local)); - for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) { - index_t idx = row * Kernel_traits::COLS + col; - - Vec dgamma_part; - dgamma_part.load_from(params.dgamma_part, idx); + using compute_t = typename Kernel_traits::compute_t; + using weight_t = typename Kernel_traits::weight_t; + using index_t = typename Kernel_traits::index_t; + using Reducer = typename Kernel_traits::Reducer; + using reduce_t = typename Reducer::Type; + + Sum sum; + enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; + + __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; + + constexpr uint32_t bidm = 0; + + const uint32_t bidn = blockIdx.x; + const uint32_t tidx = threadIdx.x; + const uint32_t warp = tidx / THREADS_PER_WARP; + const uint32_t lane = tidx % THREADS_PER_WARP; + + Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); + + const uint32_t c = bidn * THREADS_PER_WARP + lane; + const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; + constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; + for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; + col += COL_STRIDE, col_out += COL_STRIDE / 2) { + // Each thread sums over NUM_ELT columns. + Vec dbeta_local, dgamma_local; + memset(&dgamma_local, 0, sizeof(dgamma_local)); + for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) { + index_t idx = row * Kernel_traits::COLS + col; + + Vec dgamma_part; + dgamma_part.load_from(params.dgamma_part, idx); #pragma unroll - for (int it = 0; it < NUM_ELT; it++) { - dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; - } - } + for (int it = 0; it < NUM_ELT; it++) { + dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; + } + } - void *smem_gamma = smem_; + void *smem_gamma = smem_; - const int write_row = warp; - const int write_col = lane ^ write_row; - const int write_idx = write_row * THREADS_PER_WARP + write_col; + const int write_row = warp; + const int write_col = lane ^ write_row; + const int write_idx = write_row * THREADS_PER_WARP + write_col; - dgamma_local.store_to(smem_gamma, write_idx); + dgamma_local.store_to(smem_gamma, write_idx); - __syncthreads(); + __syncthreads(); - // It would be probably safe to reuse the first row of smem_gamma - void *smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; + // It would be probably safe to reuse the first row of smem_gamma + void *smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; - // More than one iter iff ROWS_PER_CTA < 32. - for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) { - const int read_row = lane; - const int read_col = w ^ read_row; - const int read_idx = read_row * THREADS_PER_WARP + read_col; + // More than one iter iff ROWS_PER_CTA < 32. + for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) { + const int read_row = lane; + const int read_col = w ^ read_row; + const int read_idx = read_row * THREADS_PER_WARP + read_col; - memset(&dgamma_local, 0, sizeof(dgamma_local)); + memset(&dgamma_local, 0, sizeof(dgamma_local)); - // Load gamma transposed - if (read_row < Kernel_traits::ROWS_PER_CTA) { - dgamma_local.load_from(smem_gamma, read_idx); - } + // Load gamma transposed + if (read_row < Kernel_traits::ROWS_PER_CTA) { + dgamma_local.load_from(smem_gamma, read_idx); + } // Call reducer on the loaded value(s) and convert. 
#pragma unroll - for (int it = 0; it < NUM_ELT; it++) { - compute_t g_i = dgamma_local.data.elt[it]; - g_i = reducer.allreduce(g_i, sum); + for (int it = 0; it < NUM_ELT; it++) { + compute_t g_i = dgamma_local.data.elt[it]; + g_i = reducer.allreduce(g_i, sum); - dgamma_local.data.elt[it] = g_i; - } + dgamma_local.data.elt[it] = g_i; + } - // Leader stores the result at the current column. - if (lane == 0) { - dgamma_local.store_to(smem_gamma_out, w); - } - } + // Leader stores the result at the current column. + if (lane == 0) { + dgamma_local.store_to(smem_gamma_out, w); + } + } - // All writes done. - __syncthreads(); + // All writes done. + __syncthreads(); - // Pack and store: 2-wide stores with half the threads. - if (warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2) { - using src_t = typename TypeToVec2::Type; - using dst_t = typename TypeToVec2::Type; - Vec dgamma_vec2; - Vec dgamma_out2; + // Pack and store: 2-wide stores with half the threads. + if (warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2) { + using src_t = typename TypeToVec2::Type; + using dst_t = typename TypeToVec2::Type; + Vec dgamma_vec2; + Vec dgamma_out2; - dgamma_vec2.load_from(smem_gamma_out, lane); + dgamma_vec2.load_from(smem_gamma_out, lane); #pragma unroll - for (int it = 0; it < NUM_ELT; it++) { - dgamma_out2.data.elt[it] = - Converter::convert(dgamma_vec2.data.elt[it]); - } - dgamma_out2.store_to(params.dgamma, col_out); - } + for (int it = 0; it < NUM_ELT; it++) { + dgamma_out2.data.elt[it] = Converter::convert(dgamma_vec2.data.elt[it]); + } + dgamma_out2.store_to(params.dgamma, col_out); } + } } template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel( BwdParams params) { - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { WARPS_N = Ktraits::WARPS_N }; - - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using compute_t = typename Ktraits::compute_t; - using output_t = typename Ktraits::output_t; - using index_t = typename Ktraits::index_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - - const index_t tidx = threadIdx.x; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / WARPS_N; - const index_t warp_n = warp % WARPS_N; - - const index_t bdimm = WARPS_M; - const index_t bdimn = WARPS_N * THREADS_PER_WARP; - const index_t bidm = blockIdx.x / params.ctas_per_row; - const index_t bidn = blockIdx.x % params.ctas_per_row; - - const index_t gdimm = bdimm * params.ctas_per_col; - const index_t gdimn = bdimn * params.ctas_per_row; - const index_t gidm = bidm * bdimm + warp_m; - const index_t gidn = - (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP + - lane); // Order threads by warp x cta x lane - - // Objects for weight grads - Cvec dzy_sum[LDGS]; - memset(dzy_sum, 0, sizeof(dzy_sum)); - - // Objects for stats reductions - using reduce_t = typename Ktraits::Reducer::Type; - using Reducer = DynamicReducer; - constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? 
Reducer::SMEM_BYTES : 1; - __shared__ char smem_[SMEM_BYTES]; - Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_); - Sum sum; - const compute_t rn = 1.f / static_cast(params.cols); - - // Load weights - Cvec gamma[LDGS]; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { WARPS_N = Ktraits::WARPS_N }; + + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using compute_t = typename Ktraits::compute_t; + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + + const index_t tidx = threadIdx.x; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t bdimm = WARPS_M; + const index_t bdimn = WARPS_N * THREADS_PER_WARP; + const index_t bidm = blockIdx.x / params.ctas_per_row; + const index_t bidn = blockIdx.x % params.ctas_per_row; + + const index_t gdimm = bdimm * params.ctas_per_col; + const index_t gdimn = bdimn * params.ctas_per_row; + const index_t gidm = bidm * bdimm + warp_m; + const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP + + lane); // Order threads by warp x cta x lane + + // Objects for weight grads + Cvec dzy_sum[LDGS]; + memset(dzy_sum, 0, sizeof(dzy_sum)); + + // Objects for stats reductions + using reduce_t = typename Ktraits::Reducer::Type; + using Reducer = DynamicReducer; + constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1; + __shared__ char smem_[SMEM_BYTES]; + Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_); + Sum sum; + const compute_t rn = 1.f / static_cast(params.cols); + + // Load weights + Cvec gamma[LDGS]; #pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; - it++, col += gdimn * NUM_ELTS) { - Wvec gamma_in; - gamma_in.load_from_elts(params.gamma, col, params.cols - col); - gamma_in.to(gamma[it]); + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + Wvec gamma_in; + gamma_in.load_from_elts(params.gamma, col, params.cols - col); + gamma_in.to(gamma[it]); + } + + for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { + const int row = cta_row + warp_m; + compute_t rs = 0.f; + if (row < params.rows) { + rs = static_cast(params.rs)[row]; } - for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { - const int row = cta_row + warp_m; - compute_t rs = 0.f; - if (row < params.rows) { - rs = static_cast(params.rs)[row]; - } - - Cvec dy[LDGS]; - Cvec y[LDGS]; - compute_t mdy = 0.f; - compute_t mdyy = 0.f; + Cvec dy[LDGS]; + Cvec y[LDGS]; + compute_t mdy = 0.f; + compute_t mdyy = 0.f; #pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { - Ivec x; - Ovec dz; - x.load_from_elts(params.x, row * params.cols + col, params.cols - col); - dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col); + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + Ivec x; + Ovec dz; + x.load_from_elts(params.x, row * params.cols + col, 
params.cols - col); + dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col); #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t x_ij = x.data.elt[jt]; - compute_t y_ij = rs * (x_ij); - const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f; - compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift; - compute_t dz_ij = dz.data.elt[jt]; - compute_t dy_ij = g_ij * dz_ij; - - y[it].data.elt[jt] = y_ij; - dy[it].data.elt[jt] = dy_ij; - - mdy += dy_ij; - mdyy += dy_ij * y_ij; - - dzy_sum[it].data.elt[jt] += dz_ij * y_ij; - } - } + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t x_ij = x.data.elt[jt]; + compute_t y_ij = rs * (x_ij); + const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f; + compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift; + compute_t dz_ij = dz.data.elt[jt]; + compute_t dy_ij = g_ij * dz_ij; + + y[it].data.elt[jt] = y_ij; + dy[it].data.elt[jt] = dy_ij; + + mdy += dy_ij; + mdyy += dy_ij * y_ij; + + dzy_sum[it].data.elt[jt] += dz_ij * y_ij; + } + } - // Reduce over row - reduce_t result = reducer.allreduce({mdy, mdyy}, sum); - mdy = Get<0>::of(result) * rn; - mdyy = Get<1>::of(result) * rn; + // Reduce over row + reduce_t result = reducer.allreduce({mdy, mdyy}, sum); + mdy = Get<0>::of(result) * rn; + mdyy = Get<1>::of(result) * rn; // Compute dx #pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { - Ivec dx; + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + Ivec dx; #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t dy_ij = dy[it].data.elt[jt]; - compute_t y_ij = y[it].data.elt[jt]; - dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij)); - } - dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col); - } + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t dy_ij = dy[it].data.elt[jt]; + compute_t y_ij = y[it].data.elt[jt]; + dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij)); + } + dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col); } + } - if constexpr (WARPS_M == 1) { + if constexpr (WARPS_M == 1) { // Write out local weight grad contributions #pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; - it++, col += gdimn * NUM_ELTS) { - dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, - params.cols - col); - } - } else { - // Reduce weight grad contributions within CTA before writing - __shared__ Cvec vecs_shared[LDGS][WARPS_M][WARPS_N][THREADS_PER_WARP + 1]; + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, params.cols - col); + } + } else { + // Reduce weight grad contributions within CTA before writing + __shared__ Cvec vecs_shared[LDGS][WARPS_M][WARPS_N][THREADS_PER_WARP + 1]; - // Reduce dzy - __syncthreads(); + // Reduce dzy + __syncthreads(); #pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; - it++, col += gdimn * NUM_ELTS) { - if (it != warp_m) { - dzy_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]); - } - } - __syncthreads(); + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + if (it != warp_m) { + dzy_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]); + } + } + __syncthreads(); 
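Editorial note, not part of this patch: the per-row math that the vectorized backward kernel above implements can be stated as a scalar sketch. With rs = 1/RMS(x) read back from the forward pass, it forms y = x*rs and dy = gamma*dz (with the optional +1 shift for zero-centered gamma), reduces the row mean of dy*y, and produces dx = rs*(dy - mean(dy*y)*y) while accumulating dz*y into the per-CTA dgamma partials. All names in the sketch are illustrative.

#include <cstddef>
#include <vector>

// Scalar CPU reference for one row of the RMSNorm backward pass (editorial sketch).
void rmsnorm_bwd_row_reference(const float *x, const float *dz, const float *gamma,
                               float rs, bool zero_centered_gamma, std::size_t cols,
                               float *dx, float *dgamma_partial) {
  std::vector<float> y(cols), dy(cols);
  float mdyy = 0.f;
  for (std::size_t j = 0; j < cols; ++j) {
    const float g = gamma[j] + (zero_centered_gamma ? 1.f : 0.f);
    y[j] = rs * x[j];
    dy[j] = g * dz[j];
    mdyy += dy[j] * y[j];
    dgamma_partial[j] += dz[j] * y[j];  // summed over rows, finalized separately
  }
  mdyy /= static_cast<float>(cols);  // corresponds to the allreduce scaled by rn
  for (std::size_t j = 0; j < cols; ++j) {
    dx[j] = rs * (dy[j] - mdyy * y[j]);
  }
}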
#pragma unroll - for (int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; it < LDGS && col < params.cols; - it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS) { + for (int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; it < LDGS && col < params.cols; + it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS) { #pragma unroll - for (int kt = 0; kt < WARPS_M; kt++) { - if (kt != warp_m) { + for (int kt = 0; kt < WARPS_M; kt++) { + if (kt != warp_m) { #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - dzy_sum[it].data.elt[jt] += vecs_shared[it][kt][warp_n][lane].data.elt[jt]; - } - } - } - dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, - params.cols - col); + for (int jt = 0; jt < NUM_ELTS; jt++) { + dzy_sum[it].data.elt[jt] += vecs_shared[it][kt][warp_n][lane].data.elt[jt]; + } } + } + dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, params.cols - col); } + } } template __global__ __launch_bounds__( WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel(BwdParams params) { - enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; - using Wvec = Vec; - using Cvec = Vec; - - const int lane = threadIdx.x % THREADS_PER_WARP; - const int warp_m = threadIdx.y; - const int warp_n = threadIdx.x / THREADS_PER_WARP; - const int col = blockIdx.x * blockDim.x + threadIdx.x; - - // Load grad contributions and accumulate locally - Cvec dgamma; - dgamma.clear(); - for (int row = warp_m; row < params.ctas_per_col && col < params.cols; row += WARPS_M) { - Cvec dgamma_part; - dgamma_part.load_from_elts(params.dgamma_part, row * params.cols + col, params.cols - col); + enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; + using Wvec = Vec; + using Cvec = Vec; + + const int lane = threadIdx.x % THREADS_PER_WARP; + const int warp_m = threadIdx.y; + const int warp_n = threadIdx.x / THREADS_PER_WARP; + const int col = blockIdx.x * blockDim.x + threadIdx.x; + + // Load grad contributions and accumulate locally + Cvec dgamma; + dgamma.clear(); + for (int row = warp_m; row < params.ctas_per_col && col < params.cols; row += WARPS_M) { + Cvec dgamma_part; + dgamma_part.load_from_elts(params.dgamma_part, row * params.cols + col, params.cols - col); #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - dgamma.data.elt[jt] += dgamma_part.data.elt[jt]; - } + for (int jt = 0; jt < NUM_ELTS; jt++) { + dgamma.data.elt[jt] += dgamma_part.data.elt[jt]; } + } - // Reduce dgamma within CTA - __shared__ Cvec vecs_shared[WARPS_M][WARPS_N][THREADS_PER_WARP + 1]; - dgamma.store_to(&vecs_shared[warp_m][warp_n][lane]); + // Reduce dgamma within CTA + __shared__ Cvec vecs_shared[WARPS_M][WARPS_N][THREADS_PER_WARP + 1]; + dgamma.store_to(&vecs_shared[warp_m][warp_n][lane]); #pragma unroll - for (int nrows = WARPS_M / 2; nrows > 0; nrows /= 2) { - __syncthreads(); - if (warp_m < nrows) { + for (int nrows = WARPS_M / 2; nrows > 0; nrows /= 2) { + __syncthreads(); + if (warp_m < nrows) { #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - vecs_shared[warp_m][warp_n][lane].data.elt[jt] += - vecs_shared[warp_m + nrows][warp_n][lane].data.elt[jt]; - } - } - } - if (warp_m == 0 && col < params.cols) { - Wvec dgamma_out; - vecs_shared[warp_m][warp_n][lane].to(dgamma_out); - dgamma_out.store_to_elts(params.dgamma, col, params.cols - col); + for (int jt = 0; jt < NUM_ELTS; jt++) { + vecs_shared[warp_m][warp_n][lane].data.elt[jt] += + vecs_shared[warp_m + nrows][warp_n][lane].data.elt[jt]; + } } + } + if (warp_m == 0 && col < params.cols) { + Wvec dgamma_out; 
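The halving loop a few lines above is a standard shared-memory tree reduction: at each step the lower half of the warps adds in the upper half's partials, so after log2(WARPS_M) steps warp 0 holds the column-wise sum of the dgamma contributions. A host-side sketch of the same pattern, assuming a power-of-two row count as WARPS_M is here, follows for reference; it is editorial and not part of the patch.

#include <cstddef>

// Host-side sketch of the in-CTA tree reduction used for the dgamma partials.
void tree_reduce_rows(float *rows, std::size_t num_rows, std::size_t cols) {
  for (std::size_t nrows = num_rows / 2; nrows > 0; nrows /= 2) {
    // In the kernel, each halving step is separated from the previous one by __syncthreads().
    for (std::size_t r = 0; r < nrows; ++r) {
      for (std::size_t c = 0; c < cols; ++c) {
        rows[r * cols + c] += rows[(r + nrows) * cols + c];
      }
    }
  }
  // rows[0..cols-1] now holds the sum over all num_rows rows.
}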
+ vecs_shared[warp_m][warp_n][lane].to(dgamma_out); + dgamma_out.store_to_elts(params.dgamma, col, params.cols - col); + } } } // namespace rmsnorm diff --git a/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index 552cd1b4bc23b1806de566e66a5386ba9231697c..3215a6a9d403f38850db1c000af74ac0a2f578f2 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -13,167 +13,147 @@ using namespace transformer_engine::rmsnorm; template -void launch_tuned_(LaunchParams &launch_params, const bool configure_params) { // NOLINT(*) - using Kernel_traits = - rmsnorm::Kernel_traits; - auto kernel = &rmsnorm_bwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * - Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::reduce_t) * 2; - } - return; +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = + rmsnorm::Kernel_traits; + auto kernel = &rmsnorm_bwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::reduce_t) * 2; } - - if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if (ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, - stream); - } - - using Kernel_traits_f = - Kernel_traits_finalize; - - auto kernel_f = &rmsnorm::rmsnorm_bwd_finalize_tuned_kernel; - kernel_f<<>>( + return; + } + + if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = 
launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, + stream); + } + + using Kernel_traits_f = + Kernel_traits_finalize; + + auto kernel_f = &rmsnorm::rmsnorm_bwd_finalize_tuned_kernel; + kernel_f<<>>( + launch_params.params); } template -void launch_general_(LaunchParams &launch_params, const bool configure_params) { // NOLINT(*) - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Instantiate kernel - using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_bwd_general_kernel; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, - Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = (ctas_per_col * WARPS_M * ctas_per_row * - sizeof(typename Kernel_traits::reduce_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Instantiate kernel + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_bwd_general_kernel; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, + Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_size = 2 * ctas_per_col; + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); } - - // 
Launch finalization kernel - constexpr uint32_t WARPS_M_FINAL = 4; - constexpr uint32_t WARPS_N_FINAL = 1; - constexpr uint32_t ELTS_N_PER_CTA_FINAL = - (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); - auto kernel_final = - &rmsnorm_bwd_finalize_general_kernel; - dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); - dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); - kernel_final<<>>(launch_params.params); + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } + + // Launch finalization kernel + constexpr uint32_t WARPS_M_FINAL = 4; + constexpr uint32_t WARPS_N_FINAL = 1; + constexpr uint32_t ELTS_N_PER_CTA_FINAL = + (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); + auto kernel_final = + &rmsnorm_bwd_finalize_general_kernel; + dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); + dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); + kernel_final<<>>(launch_params.params); } //////////////////////////////////////////////////////////////////////////////////////////////////// -#define REGISTER_BWD_TUNED_LAUNCHER( \ - HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, \ - BYTES_PER_LDG_FINALIZE) \ - void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams \ - &launch_params, \ - const bool configure_params) { \ - launch_tuned_(launch_params, configure_params); \ - } \ - static BwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_BWD_GENERAL_LAUNCHER( \ - HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG, \ - BYTES_PER_LDG_FINALIZE) \ - void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams \ - &launch_params, \ - const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static BwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) +#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ + WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_tuned_(launch_params, \ + configure_params); \ + } \ + static BwdTunedRegistrar \ + reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) + +#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ + BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_general_(launch_params, configure_params); \ + } \ + static BwdGeneralRegistrar \ + 
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index bce89fafb1ebc5edbf76ed0b244c251864125023..3c8e121540dd316326f39650d87f3746bc3a974d 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -13,122 +13,120 @@ using namespace transformer_engine::rmsnorm; template -void launch_tuned_(LaunchParams &launch_params, const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_fwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * - Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::Stats::stats_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_fwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::Stats::stats_t) * 2; } + return; + } + + if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = 
launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) + Kernel_traits::SMEM_BYTES_FWD, stream); + } } template -void launch_general_(LaunchParams &launch_params, const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_fwd_general_kernel; - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_fwd_general_kernel; + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_size = 2 * ctas_per_col; + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); } + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = 
reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ - CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_tuned_( \ - launch_params, configure_params); \ - } \ - static FwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ - WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_general_( \ - launch_params, configure_params); \ - } \ - static FwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) +#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ + WARPS_M, WARPS_N, BYTES_PER_LDG) \ + void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_tuned_(launch_params, configure_params); \ + } \ + static FwdTunedRegistrar \ + reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) + +#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ + BYTES_PER_LDG) \ + void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_general_(launch_params, configure_params); \ + } \ + static FwdGeneralRegistrar \ + reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh index e773c58799cb90156777ef50850b760f9d0fdcb1..a1cfc2293ca2e249f0540c9b453f1fb12ba2ce8e 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -9,6 +9,7 @@ #include #include + #include "../utils.cuh" namespace transformer_engine { @@ -18,261 +19,260 @@ using namespace transformer_engine; template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_kernel( FwdParams params) { - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::NUM_ELTS }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using output_t = typename Ktraits::output_t; - using index_t = typename Ktraits::index_t; - using compute_t = 
typename Ktraits::compute_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - - using Stats = typename Ktraits::Stats; - using stats_t = typename Stats::stats_t; - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / WARPS_N; - const index_t warp_n = warp % WARPS_N; - - const index_t r = bidm * ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); - - compute_t *rs_ptr = static_cast(params.rs); - - Wvec gamma[LDGS]; - index_t idx = c; + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::NUM_ELTS }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using compute_t = typename Ktraits::compute_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + + using Stats = typename Ktraits::Stats; + using stats_t = typename Stats::stats_t; + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t r = bidm * ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); + + compute_t *rs_ptr = static_cast(params.rs); + + Wvec gamma[LDGS]; + index_t idx = c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + gamma[it].load_from(params.gamma, idx); + idx += VEC_COLS_PER_LDG; + } + + constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); + + compute_t scale; + if (params.fp8_out) { + scale = *reinterpret_cast(params.scale); + } + compute_t amax = 0; + + for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { + Ivec x[LDGS]; + index_t idx = row * Ktraits::VEC_COLS + c; + compute_t xf[LDGS * NUM_ELTS]; #pragma unroll for (int it = 0; it < LDGS; it++) { - gamma[it].load_from(params.gamma, idx); - idx += VEC_COLS_PER_LDG; + x[it].load_from(params.x, idx); +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t x_ij = compute_t(x[it].data.elt[jt]); + xf[it * NUM_ELTS + jt] = x_ij; + } + idx += VEC_COLS_PER_LDG; } - constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); + stats_t s = stats.compute(xf, rn); + + compute_t mu = Get<0>::of(s); + compute_t m2 = Get<1>::of(s); + // reciprocal of root mean square + // we could optimize here to count mean square directly + compute_t rs = rsqrtf(rn * m2 + mu * mu + params.epsilon); - compute_t scale; - if (params.fp8_out) { - scale = *reinterpret_cast(params.scale); + if (bidn == 0 && warp_n == 0 && lane == 0) { + rs_ptr[row] = rs; } - 
compute_t amax = 0; - for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { - Ivec x[LDGS]; - index_t idx = row * Ktraits::VEC_COLS + c; - compute_t xf[LDGS * NUM_ELTS]; + Ovec z[LDGS]; + idx = row * Ktraits::VEC_COLS + c; #pragma unroll - for (int it = 0; it < LDGS; it++) { - x[it].load_from(params.x, idx); + for (int it = 0; it < LDGS; it++) { #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t x_ij = compute_t(x[it].data.elt[jt]); - xf[it * NUM_ELTS + jt] = x_ij; - } - idx += VEC_COLS_PER_LDG; + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t y_ij = rs * (xf[it * NUM_ELTS + jt]); + compute_t g_ij = gamma[it].data.elt[jt]; + if (params.zero_centered_gamma) { + g_ij += 1; } + compute_t temp_output = g_ij * y_ij; - stats_t s = stats.compute(xf, rn); - - compute_t mu = Get<0>::of(s); - compute_t m2 = Get<1>::of(s); - // reciprocal of root mean square - // we could optimize here to count mean square directly - compute_t rs = rsqrtf(rn * m2 + mu * mu + params.epsilon); - - if (bidn == 0 && warp_n == 0 && lane == 0) { - rs_ptr[row] = rs; + if (params.fp8_out) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(temp_output)); + temp_output = temp_output * scale; } - Ovec z[LDGS]; - idx = row * Ktraits::VEC_COLS + c; -#pragma unroll - for (int it = 0; it < LDGS; it++) { -#pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t y_ij = rs * (xf[it * NUM_ELTS + jt]); - compute_t g_ij = gamma[it].data.elt[jt]; - if (params.zero_centered_gamma) { - g_ij += 1; - } - compute_t temp_output = g_ij * y_ij; - - if (params.fp8_out) { - __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(temp_output)); - temp_output = temp_output * scale; - } - - z[it].data.elt[jt] = output_t(temp_output); - } - z[it].store_to(params.z, idx); - idx += VEC_COLS_PER_LDG; - } + z[it].data.elt[jt] = output_t(temp_output); + } + z[it].store_to(params.z, idx); + idx += VEC_COLS_PER_LDG; } - if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0 && threadIdx.y == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); - } + } + if (params.fp8_out) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0 && threadIdx.y == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); } + } } template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_kernel( FwdParams params) { - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::NUM_ELTS }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { WARPS_N = Ktraits::WARPS_N }; - - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using output_t = typename Ktraits::output_t; - using index_t = typename Ktraits::index_t; - using compute_t = typename Ktraits::compute_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - - const index_t tidx = threadIdx.x; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / WARPS_N; - const index_t warp_n = warp % WARPS_N; - - const index_t bdimm = WARPS_M; - const index_t bdimn = WARPS_N * THREADS_PER_WARP; - const index_t bidm = blockIdx.x / params.ctas_per_row; - const index_t bidn = blockIdx.x % params.ctas_per_row; - - const index_t gdimm = bdimm * params.ctas_per_col; - const index_t gdimn = bdimn * params.ctas_per_row; - const 
index_t gidm = bidm * bdimm + warp_m; - const index_t gidn = - (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP + - lane); // Order threads by warp x cta x lane - - // Objects for stats reductions - using Reducer = DynamicReducer; - constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1; - __shared__ char smem_[SMEM_BYTES]; - Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_); - Sum sum; - const compute_t rn = 1.f / static_cast(params.cols); - - // Load weights - Cvec gamma[LDGS]; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::NUM_ELTS }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { WARPS_N = Ktraits::WARPS_N }; + + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using compute_t = typename Ktraits::compute_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + + const index_t tidx = threadIdx.x; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t bdimm = WARPS_M; + const index_t bdimn = WARPS_N * THREADS_PER_WARP; + const index_t bidm = blockIdx.x / params.ctas_per_row; + const index_t bidn = blockIdx.x % params.ctas_per_row; + + const index_t gdimm = bdimm * params.ctas_per_col; + const index_t gdimn = bdimn * params.ctas_per_row; + const index_t gidm = bidm * bdimm + warp_m; + const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP + + lane); // Order threads by warp x cta x lane + + // Objects for stats reductions + using Reducer = DynamicReducer; + constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? 
Reducer::SMEM_BYTES : 1; + __shared__ char smem_[SMEM_BYTES]; + Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_); + Sum sum; + const compute_t rn = 1.f / static_cast(params.cols); + + // Load weights + Cvec gamma[LDGS]; #pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + Wvec gamma_in; + gamma_in.load_from_elts(params.gamma, col, params.cols - col); + gamma_in.to(gamma[it]); + } + + // fp8 factors + compute_t scale; + if (params.fp8_out) { + scale = *reinterpret_cast(params.scale); + } + compute_t amax = 0; + + for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { + const int row = cta_row + warp_m; + + // Load input + Cvec x[LDGS]; +#pragma unroll + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; it++, col += gdimn * NUM_ELTS) { - Wvec gamma_in; - gamma_in.load_from_elts(params.gamma, col, params.cols - col); - gamma_in.to(gamma[it]); - } - - // fp8 factors - compute_t scale; - if (params.fp8_out) { - scale = *reinterpret_cast(params.scale); + Ivec x_in; + x_in.load_from_elts(params.x, row * params.cols + col, params.cols - col); + x_in.to(x[it]); } - compute_t amax = 0; - for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { - const int row = cta_row + warp_m; - - // Load input - Cvec x[LDGS]; -#pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { - Ivec x_in; - x_in.load_from_elts(params.x, row * params.cols + col, params.cols - col); - x_in.to(x[it]); - } - - // Compute variance - compute_t sqsigma = 0.f; + // Compute variance + compute_t sqsigma = 0.f; #pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; + it++, col += gdimn * NUM_ELTS) { #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - if (col + jt < params.cols) { - compute_t diff = x[it].data.elt[jt]; - sqsigma += diff * diff; - } - } + for (int jt = 0; jt < NUM_ELTS; jt++) { + if (col + jt < params.cols) { + compute_t diff = x[it].data.elt[jt]; + sqsigma += diff * diff; } - sqsigma = reducer.allreduce(sqsigma, sum) * rn; - compute_t rs = rsqrtf(sqsigma + params.epsilon); + } + } + sqsigma = reducer.allreduce(sqsigma, sum) * rn; + compute_t rs = rsqrtf(sqsigma + params.epsilon); - // Write statistics - if (gidn == 0 && row < params.rows) { - compute_t *rs_ptr = static_cast(params.rs); - rs_ptr[row] = rs; - } + // Write statistics + if (gidn == 0 && row < params.rows) { + compute_t *rs_ptr = static_cast(params.rs); + rs_ptr[row] = rs; + } // Compute output #pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { - // Compute output values - Cvec z; -#pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t y_ij = rs * (x[it].data.elt[jt]); - compute_t g_ij = gamma[it].data.elt[jt]; - if (params.zero_centered_gamma) { - g_ij += 1; - } - z.data.elt[jt] = g_ij * y_ij; - } - - // Apply fp8 factors - if (params.fp8_out) { + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; + it++, col += gdimn * NUM_ELTS) { + // Compute output values + Cvec z; #pragma unroll - for 
(int jt = 0; jt < NUM_ELTS; jt++) { - if (col + jt < params.cols) { - compute_t z_ij = z.data.elt[jt]; - __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(z_ij)); - z.data.elt[jt] = z_ij * scale; - } - } - } - - // Store output - Ovec z_out; - z.to(z_out); - z_out.store_to_elts(params.z, row * params.cols + col, params.cols - col); + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t y_ij = rs * (x[it].data.elt[jt]); + compute_t g_ij = gamma[it].data.elt[jt]; + if (params.zero_centered_gamma) { + g_ij += 1; } - } + z.data.elt[jt] = g_ij * y_ij; + } - // Finalize fp8 factors - if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Apply fp8 factors + if (params.fp8_out) { +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + if (col + jt < params.cols) { + compute_t z_ij = z.data.elt[jt]; + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(z_ij)); + z.data.elt[jt] = z_ij * scale; + } } + } + + // Store output + Ovec z_out; + z.to(z_out); + z_out.store_to_elts(params.z, row * params.cols + col, params.cols - col); + } + } + + // Finalize fp8 factors + if (params.fp8_out) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); } + } } } // namespace rmsnorm diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 6cdbc532029979f8a51cb0155e50fe264cc47950..5cfab2f8cf8d1a1630dc49afeda8e5d2e76615f8 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -5,80 +5,65 @@ ************************************************************************/ #include + #include "common.h" namespace transformer_engine { size_t typeToSize(const transformer_engine::DType type) { - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, - return TypeInfo::size; - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, + return TypeInfo::size;); // NOLINT(*) } bool is_fp8_dtype(const transformer_engine::DType t) { - return t == transformer_engine::DType::kFloat8E4M3 || - t == transformer_engine::DType::kFloat8E5M2; + return t == transformer_engine::DType::kFloat8E4M3 || t == transformer_engine::DType::kFloat8E5M2; } void CheckInputTensor(const Tensor &t, const std::string &name) { const DType type = t.data.dtype; if (is_fp8_dtype(type)) { // FP8 input needs to have scale_inv - NVTE_CHECK(t.scale_inv.dptr != nullptr, - "FP8 input " + name + " must have inverse of scale."); + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 input " + name + " must have inverse of scale."); NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32); - NVTE_CHECK(t.scale_inv.shape == std::vector{ 1 }); + NVTE_CHECK(t.scale_inv.shape == std::vector{1}); } else { - NVTE_CHECK(t.scale.dptr == nullptr, - "Scale is not supported for non-FP8 input " + name + "."); - NVTE_CHECK(t.amax.dptr == nullptr, - "Amax is not supported for non-FP8 input " + name + "."); + NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input " + name + "."); + NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input " + name + "."); NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input " + name + "."); } - NVTE_CHECK(t.data.dptr != nullptr, - "Input " + name + " is not allocated!"); + NVTE_CHECK(t.data.dptr != nullptr, "Input " + name + " is not allocated!"); 
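As a usage note for the checks above: FP8 tensors in this API carry three auxiliary single-float32 buffers (amax, scale, scale_inv), while non-FP8 tensors must leave them null. A hedged sketch of wrapping an FP8 output with nvte_create_tensor (whose definition appears just below) follows; the header path, the kNVTEFloat8E4M3 enumerator spelling, and the unchecked cudaMalloc calls are assumptions, not taken from this patch.

#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>

// Editorial sketch: building an FP8 output tensor with the side buffers that
// CheckOutputTensor expects. nvte_create_tensor copies the shape contents, so the
// local dims array only needs to outlive the call.
NVTETensor make_fp8_output(void *data_dptr, size_t rows, size_t cols) {
  size_t dims[2] = {rows, cols};
  NVTEShape shape;
  shape.data = dims;
  shape.ndim = 2;
  // FP8 outputs must carry amax, scale and scale_inv, each a single float32.
  float *amax = nullptr, *scale = nullptr, *scale_inv = nullptr;
  cudaMalloc(&amax, sizeof(float));       // error handling omitted in this sketch
  cudaMalloc(&scale, sizeof(float));
  cudaMalloc(&scale_inv, sizeof(float));
  cudaMemset(amax, 0, sizeof(float));     // kernels max-reduce into amax, so start at zero
  // scale / scale_inv would be filled from the caller's scaling recipe in real use.
  return nvte_create_tensor(data_dptr, shape, kNVTEFloat8E4M3, amax, scale, scale_inv);
}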
} void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) { const DType type = t.data.dtype; if (is_fp8_dtype(type)) { // FP8 output needs to have scale, amax and scale_inv - NVTE_CHECK(t.amax.dptr != nullptr, - "FP8 output " + name + " must have amax tensor."); + NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output " + name + " must have amax tensor."); NVTE_CHECK(t.amax.dtype == DType::kFloat32); - NVTE_CHECK(t.amax.shape == std::vector{ 1 }); - NVTE_CHECK(t.scale_inv.dptr != nullptr, - "FP8 output " + name + " must have scale."); + NVTE_CHECK(t.amax.shape == std::vector{1}); + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 output " + name + " must have scale."); NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32); - NVTE_CHECK(t.scale_inv.shape == std::vector{ 1 }); - NVTE_CHECK(t.scale.dptr != nullptr, - "FP8 output " + name + " must have inverse of scale."); + NVTE_CHECK(t.scale_inv.shape == std::vector{1}); + NVTE_CHECK(t.scale.dptr != nullptr, "FP8 output " + name + " must have inverse of scale."); NVTE_CHECK(t.scale.dtype == DType::kFloat32); - NVTE_CHECK(t.scale.shape == std::vector{ 1 }); + NVTE_CHECK(t.scale.shape == std::vector{1}); } else { - NVTE_CHECK(t.scale.dptr == nullptr, - "Scale is not supported for non-FP8 output " + name + "."); - NVTE_CHECK(t.amax.dptr == nullptr, - "Amax is not supported for non-FP8 output " + name + "."); + NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output " + name + "."); + NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output " + name + "."); NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output " + name + "."); } if (!allow_empty) { - NVTE_CHECK(t.data.dptr != nullptr, - "Output " + name + " is not allocated!"); + NVTE_CHECK(t.data.dptr != nullptr, "Output " + name + " is not allocated!"); } } } // namespace transformer_engine -NVTETensor nvte_create_tensor(void *dptr, - const NVTEShape shape, - const NVTEDType dtype, - float *amax, - float *scale, - float *scale_inv) { +NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, float *amax, + float *scale, float *scale_inv) { transformer_engine::Tensor *ret = new transformer_engine::Tensor; ret->data.dptr = dptr; ret->data.shape = std::vector(shape.data, shape.data + shape.ndim); @@ -97,11 +82,11 @@ void nvte_destroy_tensor(NVTETensor tensor) { NVTEDType nvte_tensor_type(const NVTETensor tensor) { return static_cast( - reinterpret_cast(tensor)->data.dtype); + reinterpret_cast(tensor)->data.dtype); } NVTEShape nvte_tensor_shape(const NVTETensor tensor) { - const auto &t = *reinterpret_cast(tensor); + const auto &t = *reinterpret_cast(tensor); NVTEShape ret; ret.data = t.data.shape.data(); ret.ndim = t.data.shape.size(); @@ -109,40 +94,40 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { } void *nvte_tensor_data(const NVTETensor tensor) { - const auto &t = *reinterpret_cast(tensor); + const auto &t = *reinterpret_cast(tensor); return t.data.dptr; } float *nvte_tensor_amax(const NVTETensor tensor) { - const auto &t = *reinterpret_cast(tensor); + const auto &t = *reinterpret_cast(tensor); NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32, "Tensor's amax must have Float32 type!"); - return reinterpret_cast(t.amax.dptr); + return reinterpret_cast(t.amax.dptr); } float *nvte_tensor_scale(const NVTETensor tensor) { - const auto &t = *reinterpret_cast(tensor); + const auto &t = *reinterpret_cast(tensor); NVTE_CHECK(t.scale.dtype == 
transformer_engine::DType::kFloat32, "Tensor's scale must have Float32 type!"); - return reinterpret_cast(t.scale.dptr); + return reinterpret_cast(t.scale.dptr); } float *nvte_tensor_scale_inv(const NVTETensor tensor) { - const auto &t = *reinterpret_cast(tensor); + const auto &t = *reinterpret_cast(tensor); NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32, "Tensor's inverse of scale must have Float32 type!"); - return reinterpret_cast(t.scale_inv.dptr); + return reinterpret_cast(t.scale_inv.dptr); } -void nvte_tensor_pack_create(NVTETensorPack* pack) { +void nvte_tensor_pack_create(NVTETensorPack *pack) { for (int i = 0; i < pack->MAX_SIZE; i++) { - pack->tensors[i] = reinterpret_cast(new transformer_engine::Tensor); + pack->tensors[i] = reinterpret_cast(new transformer_engine::Tensor); } } -void nvte_tensor_pack_destroy(NVTETensorPack* pack) { +void nvte_tensor_pack_destroy(NVTETensorPack *pack) { for (int i = 0; i < pack->MAX_SIZE; i++) { - auto *t = reinterpret_cast(pack->tensors[i]); - delete t; + auto *t = reinterpret_cast(pack->tensors[i]); + delete t; } } diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 71e1ed29f3a45ebe5fb5ad74152d031eebdb4f57..b7d8b87dff0becd09adebe021e0ccd283b785402 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -4,13 +4,12 @@ * See LICENSE for license information. ************************************************************************/ +#include #include #include #include -#include - #include "../common.h" #include "../util/rtc.h" #include "../util/string.h" @@ -49,26 +48,18 @@ struct KernelConfig { /* Elements per L1 cache store to transposed output */ size_t elements_per_store_t = 0; - KernelConfig(size_t row_length, - size_t num_rows, - size_t itype_size, - size_t otype_size, - size_t load_size_, - size_t store_size_) - : load_size{load_size_} - , store_size{store_size_} { + KernelConfig(size_t row_length, size_t num_rows, size_t itype_size, size_t otype_size, + size_t load_size_, size_t store_size_) + : load_size{load_size_}, store_size{store_size_} { // Check that tiles are correctly aligned constexpr size_t cache_line_size = 128; - if (load_size % itype_size != 0 - || store_size % otype_size != 0 - || cache_line_size % itype_size != 0 - || cache_line_size % otype_size != 0) { + if (load_size % itype_size != 0 || store_size % otype_size != 0 || + cache_line_size % itype_size != 0 || cache_line_size % otype_size != 0) { return; } const size_t row_tile_elements = load_size * THREADS_PER_WARP / itype_size; const size_t col_tile_elements = store_size * THREADS_PER_WARP / otype_size; - valid = (row_length % row_tile_elements == 0 - && num_rows % col_tile_elements == 0); + valid = (row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0); if (!valid) { return; } @@ -80,12 +71,9 @@ struct KernelConfig { constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm), static_cast(cuda::sm_count())); - elements_per_load = (std::min(cache_line_size, row_tile_elements * itype_size) - / itype_size); - elements_per_store_c = (std::min(cache_line_size, row_tile_elements * otype_size) - / otype_size); - elements_per_store_t = (std::min(cache_line_size, col_tile_elements * otype_size) - / otype_size); + elements_per_load = (std::min(cache_line_size, row_tile_elements * itype_size) / 
itype_size); + elements_per_store_c = (std::min(cache_line_size, row_tile_elements * otype_size) / otype_size); + elements_per_store_t = (std::min(cache_line_size, col_tile_elements * otype_size) / otype_size); } /* Compare by estimated cost */ @@ -104,8 +92,8 @@ struct KernelConfig { const auto &st2 = other.elements_per_store_t; const auto &p2 = other.active_sm_count; const auto scale = l1 * sc1 * st1 * p1 * l2 * sc2 * st2 * p2; - const auto cost1 = (scale/l1 + scale/sc1 + scale/st1) / p1; - const auto cost2 = (scale/l2 + scale/sc2 + scale/st2) / p2; + const auto cost1 = (scale / l1 + scale / sc1 + scale / st1) / p1; + const auto cost2 = (scale / l2 + scale / sc2 + scale / st2) / p2; return cost1 < cost2; } else { return this->valid && !other.valid; @@ -114,16 +102,14 @@ struct KernelConfig { }; template -__global__ void -__launch_bounds__(block_size) -cast_transpose_general_kernel(const IType * __restrict__ const input, - const CType * __restrict__ const noop, - OType * __restrict__ const output_c, - OType * __restrict__ const output_t, - const CType * __restrict__ const scale_ptr, - CType * __restrict__ const amax_ptr, - const size_t row_length, - const size_t num_rows) { +__global__ void __launch_bounds__(block_size) + cast_transpose_general_kernel(const IType *__restrict__ const input, + const CType *__restrict__ const noop, + OType *__restrict__ const output_c, + OType *__restrict__ const output_t, + const CType *__restrict__ const scale_ptr, + CType *__restrict__ const amax_ptr, const size_t row_length, + const size_t num_rows) { if (noop != nullptr && noop[0] == 1.0f) return; // Vectorized load/store sizes @@ -165,16 +151,16 @@ cast_transpose_general_kernel(const IType * __restrict__ const input, // Note: Each thread loads num_iterations subtiles, computes amax, // casts type, and transposes in registers. 
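Functionally, the tiled kernel above is equivalent to a simple element-wise pass: each input value contributes to the running amax, is multiplied by the FP8 scale, cast, and written once in the original layout and once transposed. A scalar reference of that behaviour is sketched below for orientation; it is editorial, ignores the no-op flag and the shared-memory tiling, and uses illustrative names and types.

#include <algorithm>
#include <cmath>
#include <cstddef>

// Editorial scalar reference for the fused cast + transpose (no tiling, no vectorization).
template <typename IType, typename OType>
float cast_transpose_reference(const IType *input, OType *output_c, OType *output_t,
                               float scale, std::size_t row_length, std::size_t num_rows) {
  float amax = 0.f;
  for (std::size_t row = 0; row < num_rows; ++row) {
    for (std::size_t col = 0; col < row_length; ++col) {
      const float in = static_cast<float>(input[row * row_length + col]);
      amax = std::max(amax, std::fabs(in));    // amax is taken before scaling
      const OType out = static_cast<OType>(in * scale);
      output_c[row * row_length + col] = out;  // cast output, same layout as input
      output_t[col * num_rows + row] = out;    // transposed output
    }
  }
  return amax;  // the kernel atomically max-reduces this into the amax tensor
}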
OVecT local_output_t[nvec_in][num_iterations]; - #pragma unroll +#pragma unroll for (size_t iter = 0; iter < num_iterations; ++iter) { const size_t i1 = tidy + iter * bdimy; const size_t j1 = tidx; - #pragma unroll +#pragma unroll for (size_t i2 = 0; i2 < nvec_out; ++i2) { const size_t row = tile_row + i1 * nvec_out + i2; const size_t col = tile_col + j1 * nvec_in; if (row < num_rows) { - #pragma unroll +#pragma unroll for (size_t j2 = 0; j2 < nvec_in; ++j2) { if (col + j2 < row_length) { const CType in = input[row * row_length + col + j2]; @@ -190,24 +176,24 @@ cast_transpose_general_kernel(const IType * __restrict__ const input, } // Copy transposed output from registers to global memory - __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1]; - #pragma unroll + __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1]; +#pragma unroll for (size_t j2 = 0; j2 < nvec_in; ++j2) { - #pragma unroll +#pragma unroll for (size_t iter = 0; iter < num_iterations; ++iter) { const size_t i1 = tidy + iter * bdimy; const size_t j1 = tidx; shared_output_t[j1][i1] = local_output_t[j2][iter]; } __syncthreads(); - #pragma unroll +#pragma unroll for (size_t iter = 0; iter < num_iterations; ++iter) { const size_t i1 = tidx; const size_t j1 = tidy + iter * bdimy; const size_t row = tile_row + i1 * nvec_out; const size_t col = tile_col + j1 * nvec_in + j2; if (col < row_length) { - #pragma unroll +#pragma unroll for (size_t i2 = 0; i2 < nvec_out; ++i2) { if (row + i2 < num_rows) { output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2]; @@ -229,18 +215,15 @@ cast_transpose_general_kernel(const IType * __restrict__ const input, } // namespace -void cast_transpose(const Tensor &input, - const Tensor &noop, - Tensor *cast_output_, - Tensor *transposed_output_, - cudaStream_t stream) { +void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output_, + Tensor *transposed_output_, cudaStream_t stream) { Tensor &cast_output = *cast_output_; Tensor &transposed_output = *transposed_output_; // Check no-op flag if (noop.data.dptr != nullptr) { size_t numel = 1; - for (const auto& dim : noop.data.shape) { + for (const auto &dim : noop.data.shape) { numel *= dim; } NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, "."); @@ -254,16 +237,14 @@ void cast_transpose(const Tensor &input, CheckOutputTensor(transposed_output, "transposed_output"); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions."); - NVTE_CHECK(transposed_output.data.shape.size() == 2, - "Transposed output must have 2 dimensions."); + NVTE_CHECK(transposed_output.data.shape.size() == 2, "Transposed output must have 2 dimensions."); const size_t row_length = input.data.shape[1]; const size_t num_rows = input.data.shape[0]; NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output."); NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output."); NVTE_CHECK(transposed_output.data.shape[0] == row_length, "Wrong dimension of transposed output."); - NVTE_CHECK(transposed_output.data.shape[1] == num_rows, - "Wrong dimension of transposed output."); + NVTE_CHECK(transposed_output.data.shape[1] == num_rows, "Wrong dimension of transposed output."); // Check tensor pointers NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); @@ -276,118 +257,111 @@ void cast_transpose(const Tensor &input, 
NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr, "Cast and transposed outputs need to share scale tensor."); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output.data.dtype, OutputType, - constexpr const char *itype_name = TypeInfo::name; - constexpr const char *otype_name = TypeInfo::name; - constexpr size_t itype_size = sizeof(InputType); - constexpr size_t otype_size = sizeof(OutputType); - - // Choose between runtime-compiled or statically-compiled kernel - const bool aligned = (row_length % THREADS_PER_WARP == 0 - && num_rows % THREADS_PER_WARP == 0); - if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel - // Pick kernel config - std::vector kernel_configs; - kernel_configs.reserve(16); - auto add_config = [&](size_t load_size, size_t store_size) { - kernel_configs.emplace_back(row_length, num_rows, - itype_size, otype_size, - load_size, store_size); - }; - add_config(8, 8); - add_config(4, 8); add_config(8, 4); - add_config(4, 4); - add_config(2, 8); add_config(8, 2); - add_config(2, 4); add_config(4, 2); - add_config(2, 2); - add_config(1, 8); add_config(8, 1); - add_config(1, 4); add_config(4, 1); - add_config(1, 2); add_config(2, 1); - add_config(1, 1); - const auto &kernel_config = *std::min_element(kernel_configs.begin(), - kernel_configs.end()); - NVTE_CHECK(kernel_config.valid, "invalid kernel config"); - const size_t load_size = kernel_config.load_size; - const size_t store_size = kernel_config.store_size; - const size_t num_blocks = kernel_config.num_blocks; - - // Compile NVRTC kernel if needed and launch - auto& rtc_manager = rtc::KernelManager::instance(); - const std::string kernel_label = concat_strings("cast_transpose" - ",itype=", itype_name, - ",otype=", otype_name, - ",load_size=", load_size, - ",store_size=", store_size); - if (!rtc_manager.is_compiled(kernel_label)) { - std::string code = string_code_transpose_rtc_cast_transpose_cu; - code = regex_replace(code, "__ITYPE__", itype_name); - code = regex_replace(code, "__OTYPE__", otype_name); - code = regex_replace(code, "__LOAD_SIZE__", load_size); - code = regex_replace(code, "__STORE_SIZE__", store_size); - code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); - code = regex_replace(code, "__BLOCK_SIZE__", block_size); - rtc_manager.compile(kernel_label, - "cast_transpose_optimized_kernel", - code, - "transformer_engine/common/transpose/rtc/cast_transpose.cu"); - } - rtc_manager.launch(kernel_label, - num_blocks, block_size, 0, stream, - static_cast(input.data.dptr), - reinterpret_cast(noop.data.dptr), - static_cast(cast_output.data.dptr), - static_cast(transposed_output.data.dptr), - static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), - row_length, num_rows); - } else { // Statically-compiled general kernel - constexpr size_t load_size = 4; - constexpr size_t store_size = 4; - constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP; - constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; - const int num_blocks = (DIVUP(row_length, row_tile_size) - * DIVUP(num_rows, col_tile_size)); - cast_transpose_general_kernel - <<>>( - static_cast(input.data.dptr), - reinterpret_cast(noop.data.dptr), - static_cast(cast_output.data.dptr), - static_cast(transposed_output.data.dptr), - static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), - row_length, num_rows); - } - ); // NOLINT(*) - ); // NOLINT(*) + 
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+      input.data.dtype, InputType,
+      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
+          cast_output.data.dtype, OutputType,
+          constexpr const char *itype_name = TypeInfo<InputType>::name;
+          constexpr const char *otype_name = TypeInfo<OutputType>::name;
+          constexpr size_t itype_size = sizeof(InputType);
+          constexpr size_t otype_size = sizeof(OutputType);
+
+          // Choose between runtime-compiled or statically-compiled kernel
+          const bool aligned =
+              (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
+          if (aligned && rtc::is_enabled()) {  // Runtime-compiled tuned kernel
+            // Pick kernel config
+            std::vector<KernelConfig> kernel_configs;
+            kernel_configs.reserve(16);
+            auto add_config = [&](size_t load_size, size_t store_size) {
+              kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size,
+                                          store_size);
+            };
+            add_config(8, 8);
+            add_config(4, 8);
+            add_config(8, 4);
+            add_config(4, 4);
+            add_config(2, 8);
+            add_config(8, 2);
+            add_config(2, 4);
+            add_config(4, 2);
+            add_config(2, 2);
+            add_config(1, 8);
+            add_config(8, 1);
+            add_config(1, 4);
+            add_config(4, 1);
+            add_config(1, 2);
+            add_config(2, 1);
+            add_config(1, 1);
+            const auto &kernel_config =
+                *std::min_element(kernel_configs.begin(), kernel_configs.end());
+            NVTE_CHECK(kernel_config.valid, "invalid kernel config");
+            const size_t load_size = kernel_config.load_size;
+            const size_t store_size = kernel_config.store_size;
+            const size_t num_blocks = kernel_config.num_blocks;
+
+            // Compile NVRTC kernel if needed and launch
+            auto &rtc_manager = rtc::KernelManager::instance();
+            const std::string kernel_label = concat_strings(
+                "cast_transpose"
+                ",itype=",
+                itype_name, ",otype=", otype_name, ",load_size=", load_size,
+                ",store_size=", store_size);
+            if (!rtc_manager.is_compiled(kernel_label)) {
+              std::string code = string_code_transpose_rtc_cast_transpose_cu;
+              code = regex_replace(code, "__ITYPE__", itype_name);
+              code = regex_replace(code, "__OTYPE__", otype_name);
+              code = regex_replace(code, "__LOAD_SIZE__", load_size);
+              code = regex_replace(code, "__STORE_SIZE__", store_size);
+              code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile);
+              code = regex_replace(code, "__BLOCK_SIZE__", block_size);
+              rtc_manager.compile(kernel_label, "cast_transpose_optimized_kernel", code,
+                                  "transformer_engine/common/transpose/rtc/cast_transpose.cu");
+            }
+            rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream,
+                               static_cast<const InputType *>(input.data.dptr),
+                               reinterpret_cast<const fp32 *>(noop.data.dptr),
+                               static_cast<OutputType *>(cast_output.data.dptr),
+                               static_cast<OutputType *>(transposed_output.data.dptr),
+                               static_cast<const fp32 *>(cast_output.scale.dptr),
+                               static_cast<fp32 *>(cast_output.amax.dptr), row_length, num_rows);
+          } else {  // Statically-compiled general kernel
+            constexpr size_t load_size = 4;
+            constexpr size_t store_size = 4;
+            constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP;
+            constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP;
+            const int num_blocks =
+                (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size));
+            cast_transpose_general_kernel
+                <<<num_blocks, block_size, 0, stream>>>(
+                    static_cast<const InputType *>(input.data.dptr),
+                    reinterpret_cast<const fp32 *>(noop.data.dptr),
+                    static_cast<OutputType *>(cast_output.data.dptr),
+                    static_cast<OutputType *>(transposed_output.data.dptr),
+                    static_cast<const fp32 *>(cast_output.scale.dptr),
+                    static_cast<fp32 *>(cast_output.amax.dptr), row_length, num_rows);
+          });  // NOLINT(*)
+  );  // NOLINT(*)
 }
 }  // namespace transformer_engine
-void nvte_cast_transpose(const NVTETensor input,
-                         NVTETensor cast_output,
-                         NVTETensor transposed_output,
-                         cudaStream_t stream) {
+void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output,
+                         NVTETensor transposed_output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_cast_transpose);
   using namespace transformer_engine;
   auto noop = Tensor();
-  cast_transpose(*reinterpret_cast<const Tensor *>(input),
-                 noop,
-                 reinterpret_cast<Tensor *>(cast_output),
-                 reinterpret_cast<Tensor *>(transposed_output),
-                 stream);
+  cast_transpose(*reinterpret_cast<const Tensor *>(input), noop,
+                 reinterpret_cast<Tensor *>(cast_output),
+                 reinterpret_cast<Tensor *>(transposed_output), stream);
 }
-void nvte_cast_transpose_with_noop(const NVTETensor input,
-                                   const NVTETensor noop,
-                                   NVTETensor cast_output,
-                                   NVTETensor transposed_output,
+void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop,
+                                   NVTETensor cast_output, NVTETensor transposed_output,
                                    cudaStream_t stream) {
   NVTE_API_CALL(nvte_cast_transpose_with_noop);
   using namespace transformer_engine;
-  cast_transpose(*reinterpret_cast<const Tensor *>(input),
-                 *reinterpret_cast<const Tensor *>(noop),
-                 reinterpret_cast<Tensor *>(cast_output),
-                 reinterpret_cast<Tensor *>(transposed_output),
-                 stream);
+  cast_transpose(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(noop),
+                 reinterpret_cast<Tensor *>(cast_output),
+                 reinterpret_cast<Tensor *>(transposed_output), stream);
 }
diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu
index c823753f6e2de5c401aaf1c2b3aa296b61c0f027..147ed3afa9fa63091fc2586cb3627494058ba9c4 100644
--- a/transformer_engine/common/transpose/cast_transpose_fusion.cu
+++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu
@@ -4,16 +4,18 @@
  * See LICENSE for license information.
  ************************************************************************/
-#include #include +#include + + #include #include #include
-#include "../utils.cuh"
-#include "../util/rtc.h"
-#include "../util/string.h"
+
 #include "../common.h"
 #include "../util/math.h"
+#include "../util/rtc.h"
+#include "../util/string.h"
+#include "../utils.cuh"
 namespace transformer_engine {
@@ -26,7 +28,7 @@ namespace {
 constexpr size_t n_warps_per_tile = 8;
 constexpr size_t desired_load_size = 8;
 constexpr size_t desired_store_size = 8;
-constexpr size_t desired_load_size_dact = 4; // dAct fusion kernels use more registers
+constexpr size_t desired_load_size_dact = 4;  // dAct fusion kernels use more registers
 constexpr size_t desired_store_size_dact = 4;
 constexpr size_t threads_per_warp = static_cast<size_t>(THREADS_PER_WARP);
@@ -38,443 +40,411 @@ static_assert(cast_transpose_num_threads <= max_threads_per_block);
 /* Performance heuristics for optimized kernel parameters */
 struct KernelConfig {
-  size_t load_size = 0;   // Vector load size
-  size_t store_size = 0;  // Vector store size to transposed output
-
-  bool valid = false;     // Whether config is valid
-  bool is_dact = false;   // Whether dact is used
-  size_t num_blocks = 0;  // Number of CUDA blocks
-
-  size_t active_sm_count = 0;         // Number of active SMs
-  size_t elements_per_load = 0;       // Elements per L1 cache load
-  size_t elements_per_load_dact = 0;  // Elements per L1 cache load dact
-  size_t elements_per_store_c = 0;    // Elements per L1 cache store to cast output
-  size_t elements_per_store_t = 0;    // Elements per L1 cache store to transposed output
-
-  KernelConfig(size_t row_length,
-               size_t num_rows,
-               size_t itype_size,
-               size_t itype2_size,
-               size_t otype_size,
-               size_t load_size_,
-               size_t store_size_,
-               bool is_dact_)
-    : load_size{load_size_}
-    , store_size{store_size_}
-    , is_dact{is_dact_} {
-    if (is_dact) {
-      if (load_size > desired_load_size_dact ||
store_size > desired_store_size_dact) { - return; - } - } - - // Check that tiles are correctly aligned - constexpr size_t cache_line_size = 128; - if (load_size % itype_size != 0 - || store_size % otype_size != 0 - || cache_line_size % itype_size != 0 - || cache_line_size % otype_size != 0) { - return; - } - /* row_tile_elements */ - const size_t tile_size_x = (load_size * THREADS_PER_WARP) / itype_size; - /* col_tile_elements */ - const size_t tile_size_y = (store_size * THREADS_PER_WARP) / otype_size; - const size_t num_tiles_x = row_length / tile_size_x; - const size_t num_tiles_y = num_rows / tile_size_y; - - valid = (row_length % tile_size_x == 0 && num_rows % tile_size_y == 0); - if (!valid) { - return; - } + size_t load_size = 0; // Vector load size + size_t store_size = 0; // Vector store size to transposed output + + bool valid = false; // Whether config is valid + bool is_dact = false; // Whether dact is used + size_t num_blocks = 0; // Number of CUDA blocks + + size_t active_sm_count = 0; // Number of active SMs + size_t elements_per_load = 0; // Elements per L1 cache load + size_t elements_per_load_dact = 0; // Elements per L1 cache load dact + size_t elements_per_store_c = 0; // Elements per L1 cache store to cast output + size_t elements_per_store_t = 0; // Elements per L1 cache store to transposed output + + KernelConfig(size_t row_length, size_t num_rows, size_t itype_size, size_t itype2_size, + size_t otype_size, size_t load_size_, size_t store_size_, bool is_dact_) + : load_size{load_size_}, store_size{store_size_}, is_dact{is_dact_} { + if (is_dact) { + if (load_size > desired_load_size_dact || store_size > desired_store_size_dact) { + return; + } + } - // Number of CUDA blocks - num_blocks = num_tiles_x * num_tiles_y; - - // Parameters for performance model - constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs - active_sm_count = std::min(DIVUP(num_blocks * n_warps_per_tile, warps_per_sm), - static_cast(cuda::sm_count())); - elements_per_load = (std::min(cache_line_size, tile_size_x * itype_size) - / itype_size); - elements_per_load_dact = (std::min(cache_line_size, tile_size_x * itype2_size) - / itype2_size); - elements_per_store_c = (std::min(cache_line_size, tile_size_x * otype_size) - / otype_size); - elements_per_store_t = (std::min(cache_line_size, tile_size_y * otype_size) - / otype_size); + // Check that tiles are correctly aligned + constexpr size_t cache_line_size = 128; + if (load_size % itype_size != 0 || store_size % otype_size != 0 || + cache_line_size % itype_size != 0 || cache_line_size % otype_size != 0) { + return; + } + /* row_tile_elements */ + const size_t tile_size_x = (load_size * THREADS_PER_WARP) / itype_size; + /* col_tile_elements */ + const size_t tile_size_y = (store_size * THREADS_PER_WARP) / otype_size; + const size_t num_tiles_x = row_length / tile_size_x; + const size_t num_tiles_y = num_rows / tile_size_y; + + valid = (row_length % tile_size_x == 0 && num_rows % tile_size_y == 0); + if (!valid) { + return; } - /* Compare by estimated cost */ - bool operator<(const KernelConfig &other) const { - if (this->valid && other.valid) { - // cost ~ (1/elements_per_load - // + 1/elements_per_load_dact - // + 1/elements_per_store_c - // + 1/elements_per_store_t) / active_sms - // Note: Integer arithmetic ensures stable ordering - const auto &l1 = this->elements_per_load; - const auto &la1 = this->elements_per_load_dact; - const auto &sc1 = this->elements_per_store_c; - const auto &st1 = this->elements_per_store_t; - const auto 
&p1 = this->active_sm_count; - const auto &l2 = other.elements_per_load; - const auto &la2 = other.elements_per_load_dact; - const auto &sc2 = other.elements_per_store_c; - const auto &st2 = other.elements_per_store_t; - const auto &p2 = other.active_sm_count; - const auto scale1 = l1 * sc1 * st1 * p1 * (is_dact ? la1 : 1); - const auto scale2 = l2 * sc2 * st2 * p2 * (is_dact ? la2 : 1); - const auto scale = scale1 * scale2; - const auto cost1 = (scale/l1 + scale/sc1 + scale/st1 + (is_dact ? (scale / la1) : 0)) - / p1; - const auto cost2 = (scale/l2 + scale/sc2 + scale/st2 + (is_dact ? (scale / la2) : 0)) - / p2; - - return cost1 < cost2; - } else { - return this->valid && !other.valid; - } + // Number of CUDA blocks + num_blocks = num_tiles_x * num_tiles_y; + + // Parameters for performance model + constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs + active_sm_count = std::min(DIVUP(num_blocks * n_warps_per_tile, warps_per_sm), + static_cast(cuda::sm_count())); + elements_per_load = (std::min(cache_line_size, tile_size_x * itype_size) / itype_size); + elements_per_load_dact = (std::min(cache_line_size, tile_size_x * itype2_size) / itype2_size); + elements_per_store_c = (std::min(cache_line_size, tile_size_x * otype_size) / otype_size); + elements_per_store_t = (std::min(cache_line_size, tile_size_y * otype_size) / otype_size); + } + + /* Compare by estimated cost */ + bool operator<(const KernelConfig &other) const { + if (this->valid && other.valid) { + // cost ~ (1/elements_per_load + // + 1/elements_per_load_dact + // + 1/elements_per_store_c + // + 1/elements_per_store_t) / active_sms + // Note: Integer arithmetic ensures stable ordering + const auto &l1 = this->elements_per_load; + const auto &la1 = this->elements_per_load_dact; + const auto &sc1 = this->elements_per_store_c; + const auto &st1 = this->elements_per_store_t; + const auto &p1 = this->active_sm_count; + const auto &l2 = other.elements_per_load; + const auto &la2 = other.elements_per_load_dact; + const auto &sc2 = other.elements_per_store_c; + const auto &st2 = other.elements_per_store_t; + const auto &p2 = other.active_sm_count; + const auto scale1 = l1 * sc1 * st1 * p1 * (is_dact ? la1 : 1); + const auto scale2 = l2 * sc2 * st2 * p2 * (is_dact ? la2 : 1); + const auto scale = scale1 * scale2; + const auto cost1 = + (scale / l1 + scale / sc1 + scale / st1 + (is_dact ? (scale / la1) : 0)) / p1; + const auto cost2 = + (scale / l2 + scale / sc2 + scale / st2 + (is_dact ? 
(scale / la2) : 0)) / p2; + + return cost1 < cost2; + } else { + return this->valid && !other.valid; } + } }; - -template +template inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], OVec (&out_trans)[nvec_in], CVec &out_dbias, // NOLINT(*) typename OVec::type *output_cast_tile, - const size_t current_place, - const size_t stride, + const size_t current_place, const size_t stride, const CType scale, - CType &amax, // NOLINT(*) + CType &amax, // NOLINT(*) const int dbias_shfl_src_lane, const bool valid_store) { - using OType = typename OVec::type; - using OVecC = Vec; - - CVec step_dbias; - if constexpr (IS_DBIAS) { - step_dbias.clear(); + using OType = typename OVec::type; + using OVecC = Vec; + + CVec step_dbias; + if constexpr (IS_DBIAS) { + step_dbias.clear(); + } + +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { + OVecC out_cast; +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + const CType tmp = in[i].data.elt[j]; + if constexpr (IS_DBIAS) { + step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation + } + out_cast.data.elt[j] = static_cast(tmp * scale); + out_trans[j].data.elt[i] = static_cast(tmp * scale); // thread tile transpose + + __builtin_assume(amax >= 0); + amax = fmaxf(fabsf(tmp), amax); } - - #pragma unroll - for (unsigned int i = 0; i < nvec_out; ++i) { - OVecC out_cast; - #pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - const CType tmp = in[i].data.elt[j]; - if constexpr (IS_DBIAS) { - step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation - } - out_cast.data.elt[j] = static_cast(tmp * scale); - out_trans[j].data.elt[i] = static_cast(tmp * scale); // thread tile transpose - - __builtin_assume(amax >= 0); - amax = fmaxf(fabsf(tmp), amax); - } - if (IS_FULL_TILE || valid_store) { - out_cast.store_to(output_cast_tile, current_place + stride * i); - } + if (IS_FULL_TILE || valid_store) { + out_cast.store_to(output_cast_tile, current_place + stride * i); } - - if constexpr (IS_DBIAS) { - #pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - CType elt = step_dbias.data.elt[j]; - elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp - out_dbias.data.elt[j] += elt; - } + } + + if constexpr (IS_DBIAS) { +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + CType elt = step_dbias.data.elt[j]; + elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp + out_dbias.data.elt[j] += elt; } + } } void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/ - Tensor* workspace, - const int nvec_out) { - const size_t row_length = cast_output.data.shape[1]; - const size_t num_rows = cast_output.data.shape[0]; + Tensor *workspace, const int nvec_out) { + const size_t row_length = cast_output.data.shape[1]; + const size_t num_rows = cast_output.data.shape[0]; - const size_t tile_size_y = (nvec_out * THREADS_PER_WARP); - NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); + const size_t tile_size_y = (nvec_out * THREADS_PER_WARP); + NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); - const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y); + const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y); - workspace->data.shape = {num_rows_partial_dbias, row_length}; - workspace->data.dtype = DType::kFloat32; + workspace->data.shape = {num_rows_partial_dbias, row_length}; + workspace->data.dtype = DType::kFloat32; } -template -__global__ void 
-__launch_bounds__(reduce_dbias_num_threads) -reduce_dbias_kernel(OutputType* const dbias_output, - const ComputeType* const dbias_partial, - const int row_length, - const int num_rows) { - using ComputeVec = Vec; - using OutputVec = Vec; +template +__global__ void __launch_bounds__(reduce_dbias_num_threads) + reduce_dbias_kernel(OutputType *const dbias_output, const ComputeType *const dbias_partial, + const int row_length, const int num_rows) { + using ComputeVec = Vec; + using OutputVec = Vec; - const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; - if (thread_id * nvec >= row_length) { - return; - } - - const ComputeType* const thread_in_base = dbias_partial + thread_id * nvec; - OutputType* const thread_out_base = dbias_output + thread_id * nvec; + if (thread_id * nvec >= row_length) { + return; + } - const int stride_in_vec = row_length / nvec; + const ComputeType *const thread_in_base = dbias_partial + thread_id * nvec; + OutputType *const thread_out_base = dbias_output + thread_id * nvec; - ComputeVec ldg_vec; - ComputeVec acc_vec; acc_vec.clear(); - for (int i = 0; i < num_rows; ++i) { - ldg_vec.load_from(thread_in_base, i * stride_in_vec); - #pragma unroll - for (int e = 0; e < nvec; ++e) { - acc_vec.data.elt[e] += ldg_vec.data.elt[e]; - } - } + const int stride_in_vec = row_length / nvec; - OutputVec stg_vec; - #pragma unroll + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < num_rows; ++i) { + ldg_vec.load_from(thread_in_base, i * stride_in_vec); +#pragma unroll for (int e = 0; e < nvec; ++e) { - stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]); + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; } - stg_vec.store_to(thread_out_base, 0); + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base, 0); } template -void reduce_dbias(const Tensor &workspace, - Tensor *dbias, - const size_t row_length, - const size_t num_rows, - const int nvec_out, - cudaStream_t stream) { - constexpr int reduce_dbias_store_bytes = 8; // stg.64 - constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(InputType); - - NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape."); - - const size_t reduce_dbias_row_length = row_length; - const size_t reduce_dbias_num_rows = DIVUP(num_rows, - static_cast(nvec_out * THREADS_PER_WARP)); - const size_t reduce_dbias_num_blocks = DIVUP(row_length, - reduce_dbias_num_threads * reduce_dbias_nvec); - - using DbiasOutputType = fp32; - reduce_dbias_kernel - <<>> - (reinterpret_cast(dbias->data.dptr), - reinterpret_cast(workspace.data.dptr), - reduce_dbias_row_length, - reduce_dbias_num_rows); +void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_length, + const size_t num_rows, const int nvec_out, cudaStream_t stream) { + constexpr int reduce_dbias_store_bytes = 8; // stg.64 + constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(InputType); + + NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape."); + + const size_t reduce_dbias_row_length = row_length; + const size_t reduce_dbias_num_rows = + DIVUP(num_rows, static_cast(nvec_out * THREADS_PER_WARP)); + const size_t reduce_dbias_num_blocks = + DIVUP(row_length, reduce_dbias_num_threads * reduce_dbias_nvec); + + using DbiasOutputType = fp32; + reduce_dbias_kernel + <<>>( + reinterpret_cast(dbias->data.dptr), + 
reinterpret_cast(workspace.data.dptr), reduce_dbias_row_length, + reduce_dbias_num_rows); } - -template -__global__ void -__launch_bounds__(cast_transpose_num_threads) -cast_transpose_fused_kernel_notaligned(const Param param, - const size_t row_length, - const size_t num_rows, - const size_t num_tiles) { - using IType = typename Param::InputType; - using IType2 = typename Param::InputType2; - using OType = typename Param::OutputType; - using CType = typename Param::ComputeType; - using IVec = Vec; - using IVec2 = Vec; - using OVec = Vec; - using CVec = Vec; - - extern __shared__ char scratch[]; - - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; - const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) - / (nvec_in * THREADS_PER_WARP); - const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) - + warp_id / n_warps_per_tile; - if (tile_id >= num_tiles) { - return; - } - - const size_t tile_id_x = tile_id % num_tiles_x; - const size_t tile_id_y = tile_id / num_tiles_x; - - const size_t tile_offset = (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) - * THREADS_PER_WARP; - const size_t tile_offset_transp = (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) - * THREADS_PER_WARP; - - const IType * const my_input_tile = param.input + tile_offset; - const IType2 * const my_act_input_tile = param.act_input + tile_offset; - OType * const my_output_c_tile = param.output_c + tile_offset; - OType * const my_output_t_tile = param.output_t + tile_offset_transp; - CType * const my_partial_dbias_tile = param.workspace - + (tile_id_x * (nvec_in * THREADS_PER_WARP) - + tile_id_y * row_length); - - const size_t stride = row_length / nvec_in; - const size_t output_stride = num_rows / nvec_out; - const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; - const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; - const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP - : row_length_rest; - const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP - : row_height_rest; - - OVec * const my_scratch = reinterpret_cast(scratch) - + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) - * (THREADS_PER_WARP + 1); - - CVec * const my_dbias_scratch = reinterpret_cast(scratch); - - IVec in[2][nvec_out]; - IVec2 act_in[2][nvec_out]; - const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; - constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; - OVec out_space[n_iterations][nvec_in]; - - size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; - size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * nvec_out; - unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) - % THREADS_PER_WARP; - CType amax = 0; - const CType scale = param.scale_ptr != nullptr ? 
*param.scale_ptr : 1; - - CVec partial_dbias; - if constexpr (IS_DBIAS) { - partial_dbias.clear(); - } - - { - const bool valid_load = my_place < tile_length && - warp_id_in_tile * n_iterations < tile_height; - #pragma unroll - for (unsigned int i = 0; i < nvec_out; ++i) { - if (valid_load) { - const size_t ld_offset = current_stride + my_place + stride * i; - in[0][i].load_from(my_input_tile, ld_offset); - if constexpr (IS_DACT) { - act_in[0][i].load_from(my_act_input_tile, ld_offset); - } - } else { - in[0][i].clear(); - if constexpr (IS_DACT) { - act_in[0][i].clear(); - } - } - } - } - - #pragma unroll - for (unsigned int i = 0; i < n_iterations; ++i) { - const size_t current_place = current_stride + my_place; - const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - const unsigned int current_in = (i + 1) % 2; - if (i < n_iterations - 1) { - const bool valid_load = my_place_in < tile_length && - warp_id_in_tile * n_iterations + i + 1 < tile_height; - #pragma unroll - for (unsigned int j = 0; j < nvec_out; ++j) { - if (valid_load) { - const size_t ld_offset = current_stride + my_place_in + stride * (nvec_out + j); - in[current_in][j].load_from(my_input_tile, ld_offset); - if constexpr (IS_DACT) { - act_in[current_in][j].load_from(my_act_input_tile, ld_offset); - } - } else { - in[current_in][j].clear(); - if constexpr (IS_DACT) { - act_in[current_in][j].clear(); - } - } - } +template +__global__ void __launch_bounds__(cast_transpose_num_threads) + cast_transpose_fused_kernel_notaligned(const Param param, const size_t row_length, + const size_t num_rows, const size_t num_tiles) { + using IType = typename Param::InputType; + using IType2 = typename Param::InputType2; + using OType = typename Param::OutputType; + using CType = typename Param::ComputeType; + using IVec = Vec; + using IVec2 = Vec; + using OVec = Vec; + using CVec = Vec; + + extern __shared__ char scratch[]; + + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; + const size_t num_tiles_x = + (row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP); + const size_t tile_id = + blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile; + if (tile_id >= num_tiles) { + return; + } + + const size_t tile_id_x = tile_id % num_tiles_x; + const size_t tile_id_y = tile_id / num_tiles_x; + + const size_t tile_offset = + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP; + const size_t tile_offset_transp = + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP; + + const IType *const my_input_tile = param.input + tile_offset; + const IType2 *const my_act_input_tile = param.act_input + tile_offset; + OType *const my_output_c_tile = param.output_c + tile_offset; + OType *const my_output_t_tile = param.output_t + tile_offset_transp; + CType *const my_partial_dbias_tile = + param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length); + + const size_t stride = row_length / nvec_in; + const size_t output_stride = num_rows / nvec_out; + const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; + const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; + const unsigned int tile_length = + row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest; + const unsigned int tile_height = + row_height_rest > THREADS_PER_WARP ? 
THREADS_PER_WARP : row_height_rest; + + OVec *const my_scratch = + reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1); + + CVec *const my_dbias_scratch = reinterpret_cast(scratch); + + IVec in[2][nvec_out]; + IVec2 act_in[2][nvec_out]; + const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; + constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + OVec out_space[n_iterations][nvec_in]; + + size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; + size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * nvec_out; + unsigned int my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + CType amax = 0; + const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1; + + CVec partial_dbias; + if constexpr (IS_DBIAS) { + partial_dbias.clear(); + } + + { + const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height; +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { + if (valid_load) { + const size_t ld_offset = current_stride + my_place + stride * i; + in[0][i].load_from(my_input_tile, ld_offset); + if constexpr (IS_DACT) { + act_in[0][i].load_from(my_act_input_tile, ld_offset); } - CVec after_dact[nvec_out]; // NOLINT(*) - #pragma unroll - for (unsigned int j = 0; j < nvec_out; ++j) { - #pragma unroll - for (unsigned int k = 0; k < nvec_in; ++k) { - if constexpr (IS_DACT) { - after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) - * OP(act_in[current_in ^ 1][j].data.elt[k], {}); - } else { - after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]); - } - } + } else { + in[0][i].clear(); + if constexpr (IS_DACT) { + act_in[0][i].clear(); } - const int dbias_shfl_src_lane = (my_id_in_warp + i + warp_id_in_tile * n_iterations) - % THREADS_PER_WARP; - constexpr bool IS_FULL_TILE = false; - const bool valid_store = (my_place < tile_length) - && (warp_id_in_tile * n_iterations + i < tile_height); - - cast_and_transpose_regs - (after_dact, out_space[i], partial_dbias, my_output_c_tile, current_place, - stride, scale, amax, dbias_shfl_src_lane, valid_store); - - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += nvec_out * stride; - current_row += nvec_out; + } } - - for (unsigned int i = 0; i < nvec_in; ++i) { - #pragma unroll - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) - % THREADS_PER_WARP] = out_space[j][i]; + } + +#pragma unroll + for (unsigned int i = 0; i < n_iterations; ++i) { + const size_t current_place = current_stride + my_place; + const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + const unsigned int current_in = (i + 1) % 2; + if (i < n_iterations - 1) { + const bool valid_load = + my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height; +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { + if (valid_load) { + const size_t ld_offset = current_stride + my_place_in + stride * (nvec_out + j); + in[current_in][j].load_from(my_input_tile, ld_offset); + if constexpr (IS_DACT) { + act_in[current_in][j].load_from(my_act_input_tile, ld_offset); + } + } else { + in[current_in][j].clear(); + if constexpr (IS_DACT) { + act_in[current_in][j].clear(); + } } - __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - 
warp_id_in_tile * n_iterations) - % THREADS_PER_WARP; - current_stride = i * output_stride - + warp_id_in_tile * n_iterations * output_stride * nvec_in; - for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { - const bool valid_store = my_place < tile_height; - if (valid_store) { - my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, - current_stride + my_place); - } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += output_stride * nvec_in; + } + } + CVec after_dact[nvec_out]; // NOLINT(*) +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { +#pragma unroll + for (unsigned int k = 0; k < nvec_in; ++k) { + if constexpr (IS_DACT) { + after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * + OP(act_in[current_in ^ 1][j].data.elt[k], {}); + } else { + after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]); } - __syncthreads(); + } } - - if constexpr (IS_DBIAS) { - my_dbias_scratch[threadIdx.x] = partial_dbias; - __syncthreads(); - if (warp_id_in_tile == 0) { - #pragma unroll - for (unsigned int i = 1; i < n_warps_per_tile; ++i) { - CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP]; - #pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - partial_dbias.data.elt[j] += tmp.data.elt[j]; - } - } - if (my_id_in_warp < tile_length) { - partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp); - } + const int dbias_shfl_src_lane = + (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + constexpr bool IS_FULL_TILE = false; + const bool valid_store = + (my_place < tile_length) && (warp_id_in_tile * n_iterations + i < tile_height); + + cast_and_transpose_regs(after_dact, out_space[i], partial_dbias, + my_output_c_tile, current_place, stride, scale, + amax, dbias_shfl_src_lane, valid_store); + + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += nvec_out * stride; + current_row += nvec_out; + } + + for (unsigned int i = 0; i < nvec_in; ++i) { +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space[j][i]; + } + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { + const bool valid_store = my_place < tile_height; + if (valid_store) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, + current_stride + my_place); + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; + } + __syncthreads(); + } + + if constexpr (IS_DBIAS) { + my_dbias_scratch[threadIdx.x] = partial_dbias; + __syncthreads(); + if (warp_id_in_tile == 0) { +#pragma unroll + for (unsigned int i = 1; i < n_warps_per_tile; ++i) { + CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP]; +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + partial_dbias.data.elt[j] += tmp.data.elt[j]; } + } + if (my_id_in_warp < tile_length) { + partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp); + } } + } - /* warp tile amax reduce*/ - amax = reduce_max(amax, warp_id); + /* warp tile amax reduce*/ + amax = reduce_max(amax, warp_id); - if (threadIdx.x == 0) { - 
static_assert(std::is_same::value); - if (param.amax != nullptr) { - atomicMaxFloat(param.amax, amax); - } + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + if (param.amax != nullptr) { + atomicMaxFloat(param.amax, amax); } + } } -static const char* ActTypeToString[] = { +static const char *ActTypeToString[] = { "NoAct", // 0 "Sigmoid", // 1 "GeLU", // 2 @@ -484,1021 +454,915 @@ static const char* ActTypeToString[] = { "SReLU" // 6 }; -template +template int get_dactivation_type() { - if (OP == &sigmoid) { - return 1; - } else if (OP == &dgelu) { - return 2; - } else if (OP == &dqgelu) { - return 3; - } else if (OP == &dsilu) { - return 4; - } else if (OP == &drelu) { - return 5; - } else if (OP == &dsrelu) { - return 6; - } else { - return 0; - } + if (OP == &sigmoid) { + return 1; + } else if (OP == &dgelu) { + return 2; + } else if (OP == &dqgelu) { + return 3; + } else if (OP == &dsilu) { + return 4; + } else if (OP == &drelu) { + return 5; + } else if (OP == &dsrelu) { + return 6; + } else { + return 0; + } } template -void cast_transpose_fused(const Tensor &input, - const Tensor &act_input, - Tensor *cast_output, - Tensor *transposed_output, - Tensor *dbias, - Tensor *workspace, + ComputeType (*OP)(ComputeType, const ParamOP &)> +void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *cast_output, + Tensor *transposed_output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - CheckInputTensor(input, "cast_transpose_fused_input"); - CheckOutputTensor(*cast_output, "cast_output"); - CheckOutputTensor(*transposed_output, "transposed_output"); - - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); - NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); - NVTE_CHECK(input.data.shape == cast_output->data.shape, - "Input and C output must have the same shape."); - const size_t row_length = input.data.shape[1]; - const size_t num_rows = input.data.shape[0]; - - NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output."); - NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); - - NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, - "C and T outputs need to have the same type."); - NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, - "C and T outputs need to share amax tensor."); - NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, - "C and T outputs need to share scale tensor."); - - if constexpr (IS_DBIAS) { - CheckOutputTensor(*dbias, "dbias"); - NVTE_CHECK(dbias->data.dtype == input.data.dtype, - "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{ row_length }, - "Wrong shape of DBias."); - } - if constexpr (IS_DACT) { - CheckInputTensor(act_input, "act_input"); - NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match."); - } - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType, - using InputType2 = InputType; - using Param = CTDBiasDActParam; - - constexpr int itype_size = sizeof(InputType); - constexpr int itype2_size = sizeof(InputType2); - constexpr int otype_size = sizeof(OutputType); - - const bool aligned = (row_length % 
THREADS_PER_WARP == 0) - && (num_rows % THREADS_PER_WARP == 0); - const bool jit_compiled = aligned && rtc::is_enabled(); - - size_t load_size = (IS_DACT ? desired_load_size_dact : desired_load_size); - size_t store_size = (IS_DACT ? desired_store_size_dact : desired_store_size); - size_t num_blocks; - - if (jit_compiled) { - // Pick kernel config - std::vector kernel_configs; - kernel_configs.reserve(16); - auto add_config = [&](size_t load_size_config, size_t store_size_config) { - kernel_configs.emplace_back(row_length, num_rows, - itype_size, itype2_size, otype_size, - load_size_config, store_size_config, - IS_DACT); - }; - add_config(8, 8); - add_config(4, 8); add_config(8, 4); - add_config(4, 4); - add_config(2, 8); add_config(8, 2); - add_config(2, 4); add_config(4, 2); - add_config(2, 2); - add_config(1, 8); add_config(8, 1); - add_config(1, 4); add_config(4, 1); - add_config(1, 2); add_config(2, 1); - add_config(1, 1); - - // Select the kernel configuration with the lowest cost - const auto &kernel_config = *std::min_element(kernel_configs.begin(), - kernel_configs.end()); - NVTE_CHECK(kernel_config.valid, "invalid kernel config"); - load_size = kernel_config.load_size; - store_size = kernel_config.store_size; - num_blocks = kernel_config.num_blocks; + CheckInputTensor(input, "cast_transpose_fused_input"); + CheckOutputTensor(*cast_output, "cast_output"); + CheckOutputTensor(*transposed_output, "transposed_output"); + + NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); + NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); + NVTE_CHECK(input.data.shape == cast_output->data.shape, + "Input and C output must have the same shape."); + const size_t row_length = input.data.shape[1]; + const size_t num_rows = input.data.shape[0]; + + NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output."); + NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); + + NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, + "C and T outputs need to have the same type."); + NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, + "C and T outputs need to share amax tensor."); + NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, + "C and T outputs need to share scale tensor."); + + if constexpr (IS_DBIAS) { + CheckOutputTensor(*dbias, "dbias"); + NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{row_length}, "Wrong shape of DBias."); + } + if constexpr (IS_DACT) { + CheckInputTensor(act_input, "act_input"); + NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match."); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + cast_output->data.dtype, OutputType, using InputType2 = InputType; + using Param = CTDBiasDActParam; + + constexpr int itype_size = sizeof(InputType); + constexpr int itype2_size = sizeof(InputType2); + constexpr int otype_size = sizeof(OutputType); + + const bool aligned = + (row_length % THREADS_PER_WARP == 0) && (num_rows % THREADS_PER_WARP == 0); + const bool jit_compiled = aligned && rtc::is_enabled(); + + size_t load_size = (IS_DACT ? 
desired_load_size_dact : desired_load_size); + size_t store_size = (IS_DACT ? desired_store_size_dact : desired_store_size); + size_t num_blocks; + + if (jit_compiled) { + // Pick kernel config + std::vector kernel_configs; + kernel_configs.reserve(16); + auto add_config = [&](size_t load_size_config, size_t store_size_config) { + kernel_configs.emplace_back(row_length, num_rows, itype_size, itype2_size, otype_size, + load_size_config, store_size_config, IS_DACT); + }; + add_config(8, 8); + add_config(4, 8); + add_config(8, 4); + add_config(4, 4); + add_config(2, 8); + add_config(8, 2); + add_config(2, 4); + add_config(4, 2); + add_config(2, 2); + add_config(1, 8); + add_config(8, 1); + add_config(1, 4); + add_config(4, 1); + add_config(1, 2); + add_config(2, 1); + add_config(1, 1); + + // Select the kernel configuration with the lowest cost + const auto &kernel_config = + *std::min_element(kernel_configs.begin(), kernel_configs.end()); + NVTE_CHECK(kernel_config.valid, "invalid kernel config"); + load_size = kernel_config.load_size; + store_size = kernel_config.store_size; + num_blocks = kernel_config.num_blocks; + } + + const size_t nvec_in = load_size / itype_size; + const size_t nvec_out = store_size / otype_size; + const size_t tile_size_x = nvec_in * threads_per_warp; + const size_t tile_size_y = nvec_out * threads_per_warp; + const size_t num_tiles_x = DIVUP(row_length, tile_size_x); + const size_t num_tiles_y = DIVUP(num_rows, tile_size_y); + const size_t num_tiles = num_tiles_x * num_tiles_y; + + NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); + NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); + + if (!jit_compiled) { + num_blocks = DIVUP(num_tiles * n_warps_per_tile, n_warps_per_block); + } if constexpr (IS_DBIAS) { + if (workspace->data.dptr == nullptr) { + populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out); + return; } - - const size_t nvec_in = load_size / itype_size; - const size_t nvec_out = store_size / otype_size; - const size_t tile_size_x = nvec_in * threads_per_warp; - const size_t tile_size_y = nvec_out * threads_per_warp; - const size_t num_tiles_x = DIVUP(row_length, tile_size_x); - const size_t num_tiles_y = DIVUP(num_rows, tile_size_y); - const size_t num_tiles = num_tiles_x * num_tiles_y; - - NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); - NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); - - if (!jit_compiled) { - num_blocks = DIVUP(num_tiles * n_warps_per_tile, n_warps_per_block); - } - if constexpr (IS_DBIAS) { - if (workspace->data.dptr == nullptr) { - populate_cast_transpose_dbias_workspace_config(*cast_output, - workspace, nvec_out); - return; - } - } - - size_t VecOutputTypeSize; - switch (nvec_out) { - case 1: VecOutputTypeSize = sizeof(Vec); break; - case 2: VecOutputTypeSize = sizeof(Vec); break; - case 4: VecOutputTypeSize = sizeof(Vec); break; - case 8: VecOutputTypeSize = sizeof(Vec); break; + } + + size_t VecOutputTypeSize; + switch (nvec_out) { + case 1: + VecOutputTypeSize = sizeof(Vec); + break; + case 2: + VecOutputTypeSize = sizeof(Vec); + break; + case 4: + VecOutputTypeSize = sizeof(Vec); + break; + case 8: + VecOutputTypeSize = sizeof(Vec); + break; + } size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile * + (threads_per_warp + 1) * VecOutputTypeSize; + + if constexpr (IS_DBIAS) { + size_t VecComputeTypeSize; + switch (nvec_in) { + case 1: + VecComputeTypeSize = sizeof(Vec); + break; + case 2: + VecComputeTypeSize = sizeof(Vec); + break; + 
case 4: + VecComputeTypeSize = sizeof(Vec); + break; + case 8: + VecComputeTypeSize = sizeof(Vec); + break; } - size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile - * (threads_per_warp + 1) * VecOutputTypeSize; - - if constexpr (IS_DBIAS) { - size_t VecComputeTypeSize; - switch (nvec_in) { - case 1: VecComputeTypeSize = sizeof(Vec); break; - case 2: VecComputeTypeSize = sizeof(Vec); break; - case 4: VecComputeTypeSize = sizeof(Vec); break; - case 8: VecComputeTypeSize = sizeof(Vec); break; - } - const size_t shared_size_dbias = cast_transpose_num_threads * VecComputeTypeSize; - if (shared_size_transpose < shared_size_dbias) { - shared_size_transpose = shared_size_dbias; - } - } - - Param param; - param.input = reinterpret_cast(input.data.dptr); - param.output_c = reinterpret_cast(cast_output->data.dptr); - param.output_t = reinterpret_cast(transposed_output->data.dptr); - param.scale_ptr = reinterpret_cast(transposed_output->scale.dptr); - param.amax = reinterpret_cast(transposed_output->amax.dptr); - param.scale_inv = reinterpret_cast(cast_output->scale_inv.dptr); - if constexpr (IS_DBIAS) { - param.workspace = reinterpret_cast(workspace->data.dptr); + const size_t shared_size_dbias = cast_transpose_num_threads * VecComputeTypeSize; + if (shared_size_transpose < shared_size_dbias) { + shared_size_transpose = shared_size_dbias; } + } + + Param param; + param.input = reinterpret_cast(input.data.dptr); + param.output_c = reinterpret_cast(cast_output->data.dptr); + param.output_t = reinterpret_cast(transposed_output->data.dptr); + param.scale_ptr = reinterpret_cast(transposed_output->scale.dptr); + param.amax = reinterpret_cast(transposed_output->amax.dptr); + param.scale_inv = reinterpret_cast(cast_output->scale_inv.dptr); + if constexpr (IS_DBIAS) { + param.workspace = reinterpret_cast(workspace->data.dptr); + } if constexpr (IS_DACT) { + param.act_input = reinterpret_cast(act_input.data.dptr); + } + + // Runtime-compiled tuned kernel + if (jit_compiled) { + constexpr const char *itype_name = TypeInfo::name; + constexpr const char *itype2_name = TypeInfo::name; + constexpr const char *otype_name = TypeInfo::name; + + int dActType = 0; if constexpr (IS_DACT) { - param.act_input = reinterpret_cast(act_input.data.dptr); + dActType = get_dactivation_type(); } - // Runtime-compiled tuned kernel - if (jit_compiled) { - constexpr const char *itype_name = TypeInfo::name; - constexpr const char *itype2_name = TypeInfo::name; - constexpr const char *otype_name = TypeInfo::name; - - int dActType = 0; - if constexpr (IS_DACT) { - dActType = get_dactivation_type(); - } - - // Compile NVRTC kernel if needed and launch - auto& rtc_manager = rtc::KernelManager::instance(); - const std::string kernel_label = - concat_strings("cast_transpose_fusion" - ",itype=", itype_name, - ",itype2=", itype2_name, - ",otype=", otype_name, - ",load_size=", load_size, - ",store_size=", store_size, - ",IS_DBIAS=", IS_DBIAS, - ",IS_DACT=", IS_DACT, - ",dactivationType=", ActTypeToString[dActType]); - - if (!rtc_manager.is_compiled(kernel_label)) { - std::string code = string_code_transpose_rtc_cast_transpose_fusion_cu; - code = regex_replace(code, "__ITYPE__", itype_name); - code = regex_replace(code, "__ITYPE2__", itype2_name); - code = regex_replace(code, "__OTYPE__", otype_name); - code = regex_replace(code, "__LOAD_SIZE__", load_size); - code = regex_replace(code, "__STORE_SIZE__", store_size); - code = regex_replace(code, "__WARPS_PER_TILE__", n_warps_per_tile); - code = regex_replace(code, 
"__BLOCK_SIZE__", cast_transpose_num_threads); - code = regex_replace(code, "__IS_DBIAS__", IS_DBIAS); - code = regex_replace(code, "__IS_DACT__", IS_DACT); - code = regex_replace(code, "__DACTIVATION_TYPE__", dActType); - - rtc_manager.compile( - kernel_label, - "cast_transpose_fusion_kernel_optimized", - code, - "transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu"); - } - - rtc_manager.set_cache_config(kernel_label, CU_FUNC_CACHE_PREFER_SHARED); - - rtc_manager.launch(kernel_label, - num_blocks, cast_transpose_num_threads, shared_size_transpose, stream, - param, row_length, num_rows, num_tiles); - } else { // Statically-compiled general kernel - constexpr size_t load_size = IS_DACT ? desired_load_size_dact : - desired_load_size; - constexpr size_t store_size = IS_DACT ? desired_store_size_dact : - desired_store_size; - constexpr size_t nvec_in = load_size / itype_size; - constexpr size_t nvec_out = store_size / otype_size; - - NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); - NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); - - cudaFuncSetAttribute( - cast_transpose_fused_kernel_notaligned - , - cudaFuncAttributePreferredSharedMemoryCarveout, - 100); - cast_transpose_fused_kernel_notaligned - - <<>> - (param, row_length, num_rows, num_tiles); + // Compile NVRTC kernel if needed and launch + auto &rtc_manager = rtc::KernelManager::instance(); + const std::string kernel_label = concat_strings( + "cast_transpose_fusion" + ",itype=", + itype_name, ",itype2=", itype2_name, ",otype=", otype_name, + ",load_size=", load_size, ",store_size=", store_size, ",IS_DBIAS=", IS_DBIAS, + ",IS_DACT=", IS_DACT, ",dactivationType=", ActTypeToString[dActType]); + + if (!rtc_manager.is_compiled(kernel_label)) { + std::string code = string_code_transpose_rtc_cast_transpose_fusion_cu; + code = regex_replace(code, "__ITYPE__", itype_name); + code = regex_replace(code, "__ITYPE2__", itype2_name); + code = regex_replace(code, "__OTYPE__", otype_name); + code = regex_replace(code, "__LOAD_SIZE__", load_size); + code = regex_replace(code, "__STORE_SIZE__", store_size); + code = regex_replace(code, "__WARPS_PER_TILE__", n_warps_per_tile); + code = regex_replace(code, "__BLOCK_SIZE__", cast_transpose_num_threads); + code = regex_replace(code, "__IS_DBIAS__", IS_DBIAS); + code = regex_replace(code, "__IS_DACT__", IS_DACT); + code = regex_replace(code, "__DACTIVATION_TYPE__", dActType); + + rtc_manager.compile( + kernel_label, "cast_transpose_fusion_kernel_optimized", code, + "transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu"); } - if constexpr (IS_DBIAS) { - reduce_dbias(*workspace, dbias, row_length, num_rows, nvec_out, stream); - } - ); // NOLINT(*) - ); // NOLINT(*) + rtc_manager.set_cache_config(kernel_label, CU_FUNC_CACHE_PREFER_SHARED); + + rtc_manager.launch(kernel_label, num_blocks, cast_transpose_num_threads, + shared_size_transpose, stream, param, row_length, num_rows, + num_tiles); + } else { // Statically-compiled general kernel + constexpr size_t load_size = IS_DACT ? desired_load_size_dact : desired_load_size; + constexpr size_t store_size = IS_DACT ? 
desired_store_size_dact : desired_store_size; + constexpr size_t nvec_in = load_size / itype_size; + constexpr size_t nvec_out = store_size / otype_size; + + NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); + NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); + + cudaFuncSetAttribute( + cast_transpose_fused_kernel_notaligned, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + cast_transpose_fused_kernel_notaligned + <<>>( + param, row_length, num_rows, num_tiles); + } + + if constexpr (IS_DBIAS) { + reduce_dbias(*workspace, dbias, row_length, num_rows, nvec_out, stream); + }); // NOLINT(*) + ); // NOLINT(*) } -template -__global__ void -__launch_bounds__(cast_transpose_num_threads) -dgated_act_cast_transpose_kernel(const IType * const input, - const IType * const act_input, - OType * const output_c, - OType * const output_t, - const CType * const scale_ptr, - CType * const amax, - CType * const scale_inv, - const size_t row_length, - const size_t num_rows, - const size_t num_tiles) { - using IVec = Vec; - using OVec = Vec; - using CVec = Vec; - - extern __shared__ char scratch[]; - - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; - const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP); - const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) - + warp_id / n_warps_per_tile; - if (tile_id >= num_tiles) { - return; +template +__global__ void __launch_bounds__(cast_transpose_num_threads) + dgated_act_cast_transpose_kernel(const IType *const input, const IType *const act_input, + OType *const output_c, OType *const output_t, + const CType *const scale_ptr, CType *const amax, + CType *const scale_inv, const size_t row_length, + const size_t num_rows, const size_t num_tiles) { + using IVec = Vec; + using OVec = Vec; + using CVec = Vec; + + extern __shared__ char scratch[]; + + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; + const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP); + const size_t tile_id = + blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile; + if (tile_id >= num_tiles) { + return; + } + + const size_t tile_id_x = tile_id % num_tiles_x; + const size_t tile_id_y = tile_id / num_tiles_x; + + const IType *const my_input_tile = + input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP; + const IType *const my_act_input_tile = + act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP; + const IType *const my_gate_input_tile = + act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP + + row_length; + OType *const my_output_c_tile_0 = + output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP; + OType *const my_output_c_tile_1 = + output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP + + row_length; + OType *const my_output_t_tile_0 = + output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP; + OType *const my_output_t_tile_1 = + output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP + + row_length * num_rows; + OVec *const my_scratch = + reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1); + + IVec in[2][nvec_out]; + IVec 
act_in[2][nvec_out]; + IVec gate_in[2][nvec_out]; + const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; + constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + OVec out_space_0[n_iterations][nvec_in]; + OVec out_space_1[n_iterations][nvec_in]; + + const size_t stride = row_length / nvec_in; + const size_t output_stride = num_rows / nvec_out; + size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; + unsigned int my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + const size_t stride2 = 2 * row_length / nvec_in; + size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2; + CType max = 0; + const CType scale = scale_ptr != nullptr ? *scale_ptr : 1; + + CVec partial_dbias; + +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { + in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); + act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i); + gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i); + } +#pragma unroll + for (unsigned int i = 0; i < n_iterations; ++i) { + const size_t current_place = current_stride2 + my_place; + const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + const unsigned int current_in = (i + 1) % 2; + if (i < n_iterations - 1) { +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { + in[current_in][j].load_from(my_input_tile, + current_stride + my_place_in + stride * (nvec_out + j)); + act_in[current_in][j].load_from(my_act_input_tile, + current_stride2 + my_place_in + stride2 * (nvec_out + j)); + gate_in[current_in][j].load_from(my_gate_input_tile, + current_stride2 + my_place_in + stride2 * (nvec_out + j)); + } + } + CVec after_dact[nvec_out]; // NOLINT(*) + CVec after_dgate[nvec_out]; // NOLINT(*) +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { +#pragma unroll + for (unsigned int k = 0; k < nvec_in; ++k) { + after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) * + CType(in[current_in ^ 1][j].data.elt[k]) * + CType(gate_in[current_in ^ 1][j].data.elt[k]); + after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * + OP2(act_in[current_in ^ 1][j].data.elt[k], {}); + } + } + OVec out_trans_0[nvec_in]; // NOLINT(*) + OVec out_trans_1[nvec_in]; // NOLINT(*) + + constexpr bool IS_DBIAS = false; + constexpr bool IS_FULL_TILE = true; + constexpr bool valid_store = true; + constexpr int dbias_shfl_src_lane = 0; + + cast_and_transpose_regs(after_dact, out_trans_0, partial_dbias, + my_output_c_tile_0, current_place, stride2, + scale, max, dbias_shfl_src_lane, valid_store); + + cast_and_transpose_regs(after_dgate, out_trans_1, partial_dbias, + my_output_c_tile_1, current_place, stride2, + scale, max, dbias_shfl_src_lane, valid_store); + +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + out_space_0[i][j].data.vec = out_trans_0[j].data.vec; + out_space_1[i][j].data.vec = out_trans_1[j].data.vec; + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += nvec_out * stride; + current_stride2 += nvec_out * stride2; + } + + for (unsigned int i = 0; i < nvec_in; ++i) { +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space_0[j][i]; + } + __syncthreads(); + my_place = + (my_id_in_warp + 
THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0, + current_stride + my_place); + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; + } + __syncthreads(); +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space_1[j][i]; } + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1, + current_stride + my_place); + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; + } + __syncthreads(); + } + + /* warp tile amax reduce*/ + max = reduce_max(max, warp_id); + + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + if (amax != nullptr) { + atomicMaxFloat(amax, max); + } + if (scale_inv != nullptr) { + reciprocal(scale_inv, scale); + } + } +} - const size_t tile_id_x = tile_id % num_tiles_x; - const size_t tile_id_y = tile_id / num_tiles_x; - - const IType * const my_input_tile = input + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - const IType * const my_act_input_tile = act_input + (tile_id_x * nvec_in + - tile_id_y * row_length * 2 * nvec_out) * - THREADS_PER_WARP; - const IType * const my_gate_input_tile = act_input + (tile_id_x * nvec_in + - tile_id_y * row_length * 2 * nvec_out) * - THREADS_PER_WARP + row_length; - OType * const my_output_c_tile_0 = output_c + (tile_id_x * nvec_in + - tile_id_y * row_length * 2 * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_c_tile_1 = output_c + (tile_id_x * nvec_in + - tile_id_y * row_length * 2 * nvec_out) * - THREADS_PER_WARP + row_length; - OType * const my_output_t_tile_0 = output_t + (tile_id_y * nvec_out + - tile_id_x * num_rows * nvec_in) * - THREADS_PER_WARP; - OType * const my_output_t_tile_1 = output_t + (tile_id_y * nvec_out + - tile_id_x * num_rows * nvec_in) * - THREADS_PER_WARP + row_length * num_rows; - OVec * const my_scratch = reinterpret_cast(scratch) + - (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * - (THREADS_PER_WARP + 1); - - IVec in[2][nvec_out]; - IVec act_in[2][nvec_out]; - IVec gate_in[2][nvec_out]; - const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; - constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; - OVec out_space_0[n_iterations][nvec_in]; - OVec out_space_1[n_iterations][nvec_in]; - - const size_t stride = row_length / nvec_in; - const size_t output_stride = num_rows / nvec_out; - size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; - unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - const size_t stride2 = 2 * row_length / nvec_in; - size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2; - CType max = 0; - const CType scale = scale_ptr != nullptr ? 
*scale_ptr : 1; - - CVec partial_dbias; - - #pragma unroll +template +__global__ void __launch_bounds__(cast_transpose_num_threads) + dgated_act_cast_transpose_kernel_notaligned(const IType *const input, + const IType *const act_input, OType *const output_c, + OType *const output_t, const CType *const scale_ptr, + CType *const amax, CType *const scale_inv, + const size_t row_length, const size_t num_rows, + const size_t num_tiles) { + using IVec = Vec; + using OVec = Vec; + using CVec = Vec; + + extern __shared__ char scratch[]; + + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; + const size_t num_tiles_x = + (row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP); + const size_t tile_id = + blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile; + if (tile_id >= num_tiles) return; + const size_t tile_id_x = tile_id % num_tiles_x; + const size_t tile_id_y = tile_id / num_tiles_x; + + const IType *const my_input_tile = + input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP; + const IType *const my_act_input_tile = + act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP; + const IType *const my_gate_input_tile = + act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP + + row_length; + OType *const my_output_c_tile_0 = + output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP; + OType *const my_output_c_tile_1 = + output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP + + row_length; + OType *const my_output_t_tile_0 = + output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP; + OType *const my_output_t_tile_1 = + output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP + + row_length * num_rows; + const size_t stride = row_length / nvec_in; + const size_t stride2 = 2 * row_length / nvec_in; + const size_t output_stride = num_rows / nvec_out; + const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; + const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; + const unsigned int tile_length = + row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest; + const unsigned int tile_height = + row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest; + + OVec *const my_scratch = + reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1); + + IVec in[2][nvec_out]; + IVec act_in[2][nvec_out]; + IVec gate_in[2][nvec_out]; + const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; + constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + OVec out_space_0[n_iterations][nvec_in]; + OVec out_space_1[n_iterations][nvec_in]; + + size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; + size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2; + unsigned int my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + CType max = 0; + const CType scale = scale_ptr != nullptr ? 
*scale_ptr : 1; + + CVec partial_dbias; + + { + const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height; +#pragma unroll for (unsigned int i = 0; i < nvec_out; ++i) { + if (valid_load) { in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i); gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i); + } else { + in[0][i].clear(); + act_in[0][i].clear(); + gate_in[0][i].clear(); + } } - #pragma unroll - for (unsigned int i = 0; i < n_iterations; ++i) { - const size_t current_place = current_stride2 + my_place; - const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - const unsigned int current_in = (i + 1) % 2; - if (i < n_iterations - 1) { - #pragma unroll - for (unsigned int j = 0; j < nvec_out; ++j) { - in[current_in][j].load_from(my_input_tile, - current_stride + my_place_in + stride * (nvec_out + j)); - act_in[current_in][j].load_from(my_act_input_tile, - current_stride2 + my_place_in + stride2 * (nvec_out + j)); - gate_in[current_in][j].load_from(my_gate_input_tile, - current_stride2 + my_place_in + stride2 * (nvec_out + j)); - } - } - CVec after_dact[nvec_out]; // NOLINT(*) - CVec after_dgate[nvec_out]; // NOLINT(*) - #pragma unroll + } +#pragma unroll + for (unsigned int i = 0; i < n_iterations; ++i) { + const size_t current_place = current_stride2 + my_place; + const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + const unsigned int current_in = (i + 1) % 2; + if (i < n_iterations - 1) { + { + const bool valid_load = + my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height; +#pragma unroll for (unsigned int j = 0; j < nvec_out; ++j) { - #pragma unroll - for (unsigned int k = 0; k < nvec_in; ++k) { - after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) * - CType(in[current_in ^ 1][j].data.elt[k]) * - CType(gate_in[current_in ^ 1][j].data.elt[k]); - after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * - OP2(act_in[current_in ^ 1][j].data.elt[k], {}); - } - } - OVec out_trans_0[nvec_in]; // NOLINT(*) - OVec out_trans_1[nvec_in]; // NOLINT(*) - - constexpr bool IS_DBIAS = false; - constexpr bool IS_FULL_TILE = true; - constexpr bool valid_store = true; - constexpr int dbias_shfl_src_lane = 0; - - cast_and_transpose_regs - (after_dact, out_trans_0, partial_dbias, my_output_c_tile_0, current_place, stride2, - scale, max, dbias_shfl_src_lane, valid_store); - - cast_and_transpose_regs - (after_dgate, out_trans_1, partial_dbias, my_output_c_tile_1, current_place, stride2, - scale, max, dbias_shfl_src_lane, valid_store); - - #pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - out_space_0[i][j].data.vec = out_trans_0[j].data.vec; - out_space_1[i][j].data.vec = out_trans_1[j].data.vec; + if (valid_load) { + in[current_in][j].load_from(my_input_tile, + current_stride + my_place_in + stride * (nvec_out + j)); + act_in[current_in][j].load_from( + my_act_input_tile, current_stride2 + my_place_in + stride2 * (nvec_out + j)); + gate_in[current_in][j].load_from( + my_gate_input_tile, current_stride2 + my_place_in + stride2 * (nvec_out + j)); + } else { + in[current_in][j].clear(); + act_in[current_in][j].clear(); + gate_in[current_in][j].clear(); + } } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += nvec_out * stride; - current_stride2 += nvec_out 
* stride2; + } } - - for (unsigned int i = 0; i < nvec_in; ++i) { - #pragma unroll - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space_0[j][i]; - } - __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - current_stride = i * output_stride + - warp_id_in_tile * n_iterations * output_stride * nvec_in; - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0, - current_stride + my_place); - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += output_stride * nvec_in; - } - __syncthreads(); - #pragma unroll - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space_1[j][i]; - } - __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - current_stride = i * output_stride + - warp_id_in_tile * n_iterations * output_stride * nvec_in; - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1, - current_stride + my_place); - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += output_stride * nvec_in; - } - __syncthreads(); + CVec after_dact[nvec_out]; // NOLINT(*) + CVec after_dgate[nvec_out]; // NOLINT(*) +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { +#pragma unroll + for (unsigned int k = 0; k < nvec_in; ++k) { + after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) * + CType(in[current_in ^ 1][j].data.elt[k]) * + CType(gate_in[current_in ^ 1][j].data.elt[k]); + after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * + OP2(act_in[current_in ^ 1][j].data.elt[k], {}); + } } - - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) { - atomicMaxFloat(amax, max); - } - if (scale_inv != nullptr) { - reciprocal(scale_inv, scale); - } + OVec out_trans_0[nvec_in]; // NOLINT(*) + OVec out_trans_1[nvec_in]; // NOLINT(*) + + constexpr bool IS_DBIAS = false; + constexpr bool IS_FULL_TILE = false; + constexpr int dbias_shfl_src_lane = 0; + const bool valid_store = + (my_place < tile_length) && (warp_id_in_tile * n_iterations + i < tile_height); + + cast_and_transpose_regs(after_dact, out_trans_0, partial_dbias, + my_output_c_tile_0, current_place, stride2, + scale, max, dbias_shfl_src_lane, valid_store); + cast_and_transpose_regs(after_dgate, out_trans_1, partial_dbias, + my_output_c_tile_1, current_place, stride2, + scale, max, dbias_shfl_src_lane, valid_store); + +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + out_space_0[i][j].data.vec = out_trans_0[j].data.vec; + out_space_1[i][j].data.vec = out_trans_1[j].data.vec; } -} - -template -__global__ void -__launch_bounds__(cast_transpose_num_threads) -dgated_act_cast_transpose_kernel_notaligned(const IType * const input, - const IType * const act_input, - OType * const output_c, - OType * const output_t, - const CType * const scale_ptr, - CType * const amax, - CType * const scale_inv, - const size_t row_length, - const size_t num_rows, - const size_t num_tiles) { - using IVec = Vec; - using OVec = Vec; - using CVec = Vec; - - extern 
__shared__ char scratch[]; - - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; - const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) / - (nvec_in * THREADS_PER_WARP); - const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + - warp_id / n_warps_per_tile; - if (tile_id >= num_tiles) return; - const size_t tile_id_x = tile_id % num_tiles_x; - const size_t tile_id_y = tile_id / num_tiles_x; - - const IType * const my_input_tile = input + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - const IType * const my_act_input_tile = act_input + (tile_id_x * nvec_in + - tile_id_y * row_length * 2 * nvec_out) * - THREADS_PER_WARP; - const IType * const my_gate_input_tile = act_input + (tile_id_x * nvec_in + - tile_id_y * row_length * 2 * nvec_out) * - THREADS_PER_WARP + row_length; - OType * const my_output_c_tile_0 = output_c + (tile_id_x * nvec_in + - tile_id_y * row_length * 2 * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_c_tile_1 = output_c + (tile_id_x * nvec_in + - tile_id_y * row_length * 2 * nvec_out) * - THREADS_PER_WARP + row_length; - OType * const my_output_t_tile_0 = output_t + (tile_id_y * nvec_out + - tile_id_x * num_rows * nvec_in) * - THREADS_PER_WARP; - OType * const my_output_t_tile_1 = output_t + (tile_id_y * nvec_out + - tile_id_x * num_rows * nvec_in) * - THREADS_PER_WARP + row_length * num_rows; - const size_t stride = row_length / nvec_in; - const size_t stride2 = 2 * row_length / nvec_in; - const size_t output_stride = num_rows / nvec_out; - const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; - const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; - const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP - : row_length_rest; - const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP - : row_height_rest; - - OVec * const my_scratch = reinterpret_cast(scratch) + - (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * - (THREADS_PER_WARP + 1); - - IVec in[2][nvec_out]; - IVec act_in[2][nvec_out]; - IVec gate_in[2][nvec_out]; - const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; - constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; - OVec out_space_0[n_iterations][nvec_in]; - OVec out_space_1[n_iterations][nvec_in]; - - size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; - size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2; - unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - CType max = 0; - const CType scale = scale_ptr != nullptr ? 
*scale_ptr : 1; - - CVec partial_dbias; - - { - const bool valid_load = my_place < tile_length && - warp_id_in_tile * n_iterations < tile_height; - #pragma unroll - for (unsigned int i = 0; i < nvec_out; ++i) { - if (valid_load) { - in[0][i].load_from(my_input_tile, - current_stride + my_place + stride * i); - act_in[0][i].load_from(my_act_input_tile, - current_stride2 + my_place + stride2 * i); - gate_in[0][i].load_from(my_gate_input_tile, - current_stride2 + my_place + stride2 * i); - } else { - in[0][i].clear(); - act_in[0][i].clear(); - gate_in[0][i].clear(); - } - } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += nvec_out * stride; + current_stride2 += nvec_out * stride2; + } + + for (unsigned int i = 0; i < nvec_in; ++i) { +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space_0[j][i]; } - #pragma unroll - for (unsigned int i = 0; i < n_iterations; ++i) { - const size_t current_place = current_stride2 + my_place; - const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - const unsigned int current_in = (i + 1) % 2; - if (i < n_iterations - 1) { - { - const bool valid_load = my_place_in < tile_length && - warp_id_in_tile * n_iterations + i + 1 < tile_height; - #pragma unroll - for (unsigned int j = 0; j < nvec_out; ++j) { - if (valid_load) { - in[current_in][j].load_from(my_input_tile, - current_stride + my_place_in + stride * (nvec_out + j)); - act_in[current_in][j].load_from(my_act_input_tile, - current_stride2 + my_place_in + stride2 * (nvec_out + j)); - gate_in[current_in][j].load_from(my_gate_input_tile, - current_stride2 + my_place_in + stride2 * (nvec_out + j)); - } else { - in[current_in][j].clear(); - act_in[current_in][j].clear(); - gate_in[current_in][j].clear(); - } - } - } - } - CVec after_dact[nvec_out]; // NOLINT(*) - CVec after_dgate[nvec_out]; // NOLINT(*) - #pragma unroll - for (unsigned int j = 0; j < nvec_out; ++j) { - #pragma unroll - for (unsigned int k = 0; k < nvec_in; ++k) { - after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) * - CType(in[current_in ^ 1][j].data.elt[k]) * - CType(gate_in[current_in ^ 1][j].data.elt[k]); - after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * - OP2(act_in[current_in ^ 1][j].data.elt[k], {}); - } - } - OVec out_trans_0[nvec_in]; // NOLINT(*) - OVec out_trans_1[nvec_in]; // NOLINT(*) - - constexpr bool IS_DBIAS = false; - constexpr bool IS_FULL_TILE = false; - constexpr int dbias_shfl_src_lane = 0; - const bool valid_store = (my_place < tile_length) - && (warp_id_in_tile * n_iterations + i < tile_height); - - cast_and_transpose_regs - (after_dact, out_trans_0, partial_dbias, my_output_c_tile_0, current_place, stride2, - scale, max, dbias_shfl_src_lane, valid_store); - cast_and_transpose_regs - (after_dgate, out_trans_1, partial_dbias, my_output_c_tile_1, current_place, stride2, - scale, max, dbias_shfl_src_lane, valid_store); - - #pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - out_space_0[i][j].data.vec = out_trans_0[j].data.vec; - out_space_1[i][j].data.vec = out_trans_1[j].data.vec; - } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += nvec_out * stride; - current_stride2 += nvec_out * stride2; + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i 
* output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { + const bool valid_store = my_place < tile_height; + if (valid_store) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0, + current_stride + my_place); + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; } - - for (unsigned int i = 0; i < nvec_in; ++i) { - #pragma unroll - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space_0[j][i]; - } - __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - current_stride = i * output_stride + - warp_id_in_tile * n_iterations * output_stride * nvec_in; - for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { - const bool valid_store = my_place < tile_height; - if (valid_store) { - my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0, - current_stride + my_place); - } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += output_stride * nvec_in; - } - __syncthreads(); - #pragma unroll - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space_1[j][i]; - } - __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - current_stride = i * output_stride + - warp_id_in_tile * n_iterations * output_stride * nvec_in; - for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { - const bool valid_store = my_place < tile_height; - if (valid_store) { - my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1, - current_stride + my_place); - } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += output_stride * nvec_in; - } - __syncthreads(); + __syncthreads(); +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space_1[j][i]; + } + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { + const bool valid_store = my_place < tile_height; + if (valid_store) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1, + current_stride + my_place); + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; } + __syncthreads(); + } - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); + /* warp tile amax reduce*/ + max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) { - atomicMaxFloat(amax, max); - } - if (scale_inv != nullptr) { - reciprocal(scale_inv, scale); - } + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + if (amax != nullptr) { + atomicMaxFloat(amax, max); + } + if (scale_inv != nullptr) { + reciprocal(scale_inv, scale); } + } } -template -void dgated_act_cast_transpose(const 
Tensor &input, - const Tensor &gated_act_input, - Tensor *cast_output, - Tensor *transposed_output, - cudaStream_t stream) { - CheckInputTensor(input, "dgated_act_cast_transpose_input"); - CheckInputTensor(gated_act_input, "dgated_act_cast_transpose_gated_act_input"); - CheckOutputTensor(*cast_output, "dgated_act_cast_transpose_cast_output"); - CheckOutputTensor(*transposed_output, "dgated_act_cast_transpose_transposed_output"); - - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(gated_act_input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); - NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); - const size_t row_length = input.data.shape[1]; - const size_t num_rows = input.data.shape[0]; - - NVTE_CHECK(gated_act_input.data.shape[0] == num_rows, "Wrong dimension of output."); - NVTE_CHECK(gated_act_input.data.shape[1] == row_length * 2, "Wrong dimension of output."); - NVTE_CHECK(cast_output->data.shape[0] == num_rows, "Wrong dimension of output."); - NVTE_CHECK(cast_output->data.shape[1] == row_length * 2, "Wrong dimension of output."); - NVTE_CHECK(transposed_output->data.shape[0] == row_length * 2, "Wrong dimension of T output."); - NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); - - NVTE_CHECK(input.data.dtype == gated_act_input.data.dtype, "Types of both inputs must match."); - - NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, - "C and T outputs need to have the same type."); - NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, - "C and T outputs need to share amax tensor."); - NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, - "C and T outputs need to share scale tensor."); - NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr, - "C and T outputs need to share scale inverse tensor."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType, - using InputType2 = InputType; - /* dact fusion kernel uses more registers */ - constexpr int desired_load_size_dact = 4; - constexpr int desired_store_size_dact = 4; - constexpr int itype_size = sizeof(InputType); - constexpr int otype_size = sizeof(OutputType); - constexpr int nvec_in = desired_load_size_dact / itype_size; - constexpr int nvec_out = desired_store_size_dact / otype_size; - - NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); - NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); - const size_t n_tiles = - DIVUP(row_length, static_cast(nvec_in * THREADS_PER_WARP)) * - DIVUP(num_rows, static_cast(nvec_out * THREADS_PER_WARP)); - const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP; - const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block); - - const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && - num_rows % (nvec_out * THREADS_PER_WARP) == 0; - const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile * - (THREADS_PER_WARP + 1) * sizeof(Vec); - if (full_tile) { - cudaFuncSetAttribute( - dgated_act_cast_transpose_kernel - , - cudaFuncAttributePreferredSharedMemoryCarveout, - 100); - - dgated_act_cast_transpose_kernel - - <<>>( - reinterpret_cast(input.data.dptr), - reinterpret_cast(gated_act_input.data.dptr), - reinterpret_cast(cast_output->data.dptr), 
- reinterpret_cast(transposed_output->data.dptr), - reinterpret_cast(cast_output->scale.dptr), - reinterpret_cast(cast_output->amax.dptr), - reinterpret_cast(cast_output->scale_inv.dptr), - row_length, num_rows, n_tiles); - } else { - cudaFuncSetAttribute( - dgated_act_cast_transpose_kernel_notaligned - , - cudaFuncAttributePreferredSharedMemoryCarveout, - 100); - dgated_act_cast_transpose_kernel_notaligned - - <<>>( - reinterpret_cast(input.data.dptr), - reinterpret_cast(gated_act_input.data.dptr), - reinterpret_cast(cast_output->data.dptr), - reinterpret_cast(transposed_output->data.dptr), - reinterpret_cast(cast_output->scale.dptr), - reinterpret_cast(cast_output->amax.dptr), - reinterpret_cast(cast_output->scale_inv.dptr), - row_length, num_rows, n_tiles); - } - ); // NOLINT(*) - ); // NOLINT(*) +template +void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, + Tensor *cast_output, Tensor *transposed_output, + cudaStream_t stream) { + CheckInputTensor(input, "dgated_act_cast_transpose_input"); + CheckInputTensor(gated_act_input, "dgated_act_cast_transpose_gated_act_input"); + CheckOutputTensor(*cast_output, "dgated_act_cast_transpose_cast_output"); + CheckOutputTensor(*transposed_output, "dgated_act_cast_transpose_transposed_output"); + + NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(gated_act_input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); + NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); + const size_t row_length = input.data.shape[1]; + const size_t num_rows = input.data.shape[0]; + + NVTE_CHECK(gated_act_input.data.shape[0] == num_rows, "Wrong dimension of output."); + NVTE_CHECK(gated_act_input.data.shape[1] == row_length * 2, "Wrong dimension of output."); + NVTE_CHECK(cast_output->data.shape[0] == num_rows, "Wrong dimension of output."); + NVTE_CHECK(cast_output->data.shape[1] == row_length * 2, "Wrong dimension of output."); + NVTE_CHECK(transposed_output->data.shape[0] == row_length * 2, "Wrong dimension of T output."); + NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); + + NVTE_CHECK(input.data.dtype == gated_act_input.data.dtype, "Types of both inputs must match."); + + NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, + "C and T outputs need to have the same type."); + NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, + "C and T outputs need to share amax tensor."); + NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, + "C and T outputs need to share scale tensor."); + NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr, + "C and T outputs need to share scale inverse tensor."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + cast_output->data.dtype, OutputType, using InputType2 = InputType; + /* dact fusion kernel uses more registers */ + constexpr int desired_load_size_dact = 4; + constexpr int desired_store_size_dact = 4; constexpr int itype_size = sizeof(InputType); + constexpr int otype_size = sizeof(OutputType); + constexpr int nvec_in = desired_load_size_dact / itype_size; + constexpr int nvec_out = desired_store_size_dact / otype_size; + + NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); + NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); + const 
size_t n_tiles = + DIVUP(row_length, static_cast(nvec_in * THREADS_PER_WARP)) * + DIVUP(num_rows, static_cast(nvec_out * THREADS_PER_WARP)); + const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP; + const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block); + + const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && + num_rows % (nvec_out * THREADS_PER_WARP) == 0; + const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile * + (THREADS_PER_WARP + 1) * sizeof(Vec); + if (full_tile) { + cudaFuncSetAttribute( + dgated_act_cast_transpose_kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + dgated_act_cast_transpose_kernel + <<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(gated_act_input.data.dptr), + reinterpret_cast(cast_output->data.dptr), + reinterpret_cast(transposed_output->data.dptr), + reinterpret_cast(cast_output->scale.dptr), + reinterpret_cast(cast_output->amax.dptr), + reinterpret_cast(cast_output->scale_inv.dptr), row_length, num_rows, + n_tiles); + } else { + cudaFuncSetAttribute( + dgated_act_cast_transpose_kernel_notaligned, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + dgated_act_cast_transpose_kernel_notaligned + <<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(gated_act_input.data.dptr), + reinterpret_cast(cast_output->data.dptr), + reinterpret_cast(transposed_output->data.dptr), + reinterpret_cast(cast_output->scale.dptr), + reinterpret_cast(cast_output->amax.dptr), + reinterpret_cast(cast_output->scale_inv.dptr), row_length, num_rows, + n_tiles); + }); // NOLINT(*) + ); // NOLINT(*) } -} // namespace +} // namespace } // namespace transformer_engine - using ComputeType = typename transformer_engine::fp32; -void nvte_cast_transpose_dbias(const NVTETensor input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, +void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output, + NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_cast_transpose_dbias); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = false; - - constexpr const NVTETensor activation_input = nullptr; - - cast_transpose_fused( - *reinterpret_cast(input), - *reinterpret_cast(activation_input), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - reinterpret_cast(dbias), - reinterpret_cast(workspace), - stream); + NVTE_API_CALL(nvte_cast_transpose_dbias); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + + constexpr const NVTETensor activation_input = nullptr; + + cast_transpose_fused( + *reinterpret_cast(input), *reinterpret_cast(activation_input), + reinterpret_cast(cast_output), reinterpret_cast(transposed_output), + reinterpret_cast(dbias), reinterpret_cast(workspace), stream); } -void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, - const NVTETensor act_input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - constexpr auto dActivation = &dgelu; - - cast_transpose_fused( - *reinterpret_cast(input), - *reinterpret_cast(act_input), - reinterpret_cast(cast_output), - 
reinterpret_cast(transposed_output), - reinterpret_cast(dbias), - reinterpret_cast(workspace), - stream); +void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor cast_output, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + constexpr auto dActivation = &dgelu; + + cast_transpose_fused( + *reinterpret_cast(input), *reinterpret_cast(act_input), + reinterpret_cast(cast_output), reinterpret_cast(transposed_output), + reinterpret_cast(dbias), reinterpret_cast(workspace), stream); } -void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, - const NVTETensor silu_input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - constexpr auto dActivation = &dsilu; - - cast_transpose_fused( - *reinterpret_cast(input), - *reinterpret_cast(silu_input), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - reinterpret_cast(dbias), - reinterpret_cast(workspace), - stream); +void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input, + NVTETensor cast_output, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + constexpr auto dActivation = &dsilu; + + cast_transpose_fused( + *reinterpret_cast(input), *reinterpret_cast(silu_input), + reinterpret_cast(cast_output), reinterpret_cast(transposed_output), + reinterpret_cast(dbias), reinterpret_cast(workspace), stream); } -void nvte_cast_transpose_dbias_drelu(const NVTETensor input, - const NVTETensor relu_input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_cast_transpose_dbias_drelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - constexpr auto dActivation = &drelu; - - cast_transpose_fused( - *reinterpret_cast(input), - *reinterpret_cast(relu_input), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - reinterpret_cast(dbias), - reinterpret_cast(workspace), - stream); +void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input, + NVTETensor cast_output, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + constexpr auto dActivation = &drelu; + + cast_transpose_fused( + *reinterpret_cast(input), *reinterpret_cast(relu_input), + reinterpret_cast(cast_output), reinterpret_cast(transposed_output), + reinterpret_cast(dbias), reinterpret_cast(workspace), stream); } -void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, - const NVTETensor srelu_input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream) { - 
NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - constexpr auto dActivation = &dsrelu; - - cast_transpose_fused( - *reinterpret_cast(input), - *reinterpret_cast(srelu_input), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - reinterpret_cast(dbias), - reinterpret_cast(workspace), - stream); +void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input, + NVTETensor cast_output, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + constexpr auto dActivation = &dsrelu; + + cast_transpose_fused( + *reinterpret_cast(input), *reinterpret_cast(srelu_input), + reinterpret_cast(cast_output), reinterpret_cast(transposed_output), + reinterpret_cast(dbias), reinterpret_cast(workspace), stream); } -void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, - const NVTETensor qgelu_input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - constexpr auto dActivation = &dqgelu; - - cast_transpose_fused( - *reinterpret_cast(input), - *reinterpret_cast(qgelu_input), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - reinterpret_cast(dbias), - reinterpret_cast(workspace), - stream); +void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input, + NVTETensor cast_output, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + constexpr auto dActivation = &dqgelu; + + cast_transpose_fused( + *reinterpret_cast(input), *reinterpret_cast(qgelu_input), + reinterpret_cast(cast_output), reinterpret_cast(transposed_output), + reinterpret_cast(dbias), reinterpret_cast(workspace), stream); } -void nvte_dgeglu_cast_transpose(const NVTETensor input, - const NVTETensor gated_act_input, - NVTETensor cast_output, - NVTETensor transposed_output, +void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, + NVTETensor cast_output, NVTETensor transposed_output, cudaStream_t stream) { - NVTE_API_CALL(nvte_dgeglu_cast_transpose); - using namespace transformer_engine; - - constexpr auto dActivation = &dgelu; - constexpr auto Activation = &gelu; - - dgated_act_cast_transpose( - *reinterpret_cast(input), - *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - stream); + NVTE_API_CALL(nvte_dgeglu_cast_transpose); + using namespace transformer_engine; + + constexpr auto dActivation = &dgelu; + constexpr auto Activation = &gelu; + + dgated_act_cast_transpose( + *reinterpret_cast(input), *reinterpret_cast(gated_act_input), + reinterpret_cast(cast_output), reinterpret_cast(transposed_output), + stream); } -void nvte_dswiglu_cast_transpose(const NVTETensor input, - const NVTETensor swiglu_input, - NVTETensor cast_output, - NVTETensor transposed_output, +void nvte_dswiglu_cast_transpose(const NVTETensor 
input, const NVTETensor swiglu_input, + NVTETensor cast_output, NVTETensor transposed_output, cudaStream_t stream) { - NVTE_API_CALL(nvte_dswiglu_cast_transpose); - using namespace transformer_engine; - - constexpr auto dActivation = &dsilu; - constexpr auto Activation = &silu; - - dgated_act_cast_transpose( - *reinterpret_cast(input), - *reinterpret_cast(swiglu_input), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - stream); + NVTE_API_CALL(nvte_dswiglu_cast_transpose); + using namespace transformer_engine; + + constexpr auto dActivation = &dsilu; + constexpr auto Activation = &silu; + + dgated_act_cast_transpose( + *reinterpret_cast(input), *reinterpret_cast(swiglu_input), + reinterpret_cast(cast_output), reinterpret_cast(transposed_output), + stream); } -void nvte_dreglu_cast_transpose(const NVTETensor input, - const NVTETensor gated_act_input, - NVTETensor cast_output, - NVTETensor transposed_output, +void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, + NVTETensor cast_output, NVTETensor transposed_output, cudaStream_t stream) { - NVTE_API_CALL(nvte_dreglu_cast_transpose); - using namespace transformer_engine; - - constexpr auto dActivation = &drelu; - constexpr auto Activation = &relu; - - dgated_act_cast_transpose( - *reinterpret_cast(input), - *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - stream); + NVTE_API_CALL(nvte_dreglu_cast_transpose); + using namespace transformer_engine; + + constexpr auto dActivation = &drelu; + constexpr auto Activation = &relu; + + dgated_act_cast_transpose( + *reinterpret_cast(input), *reinterpret_cast(gated_act_input), + reinterpret_cast(cast_output), reinterpret_cast(transposed_output), + stream); } -void nvte_dsreglu_cast_transpose(const NVTETensor input, - const NVTETensor gated_act_input, - NVTETensor cast_output, - NVTETensor transposed_output, +void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, + NVTETensor cast_output, NVTETensor transposed_output, cudaStream_t stream) { - NVTE_API_CALL(nvte_dsreglu_cast_transpose); - using namespace transformer_engine; - - constexpr auto dActivation = &dsrelu; - constexpr auto Activation = &srelu; - - dgated_act_cast_transpose( - *reinterpret_cast(input), - *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - stream); + NVTE_API_CALL(nvte_dsreglu_cast_transpose); + using namespace transformer_engine; + + constexpr auto dActivation = &dsrelu; + constexpr auto Activation = &srelu; + + dgated_act_cast_transpose( + *reinterpret_cast(input), *reinterpret_cast(gated_act_input), + reinterpret_cast(cast_output), reinterpret_cast(transposed_output), + stream); } -void nvte_dqgeglu_cast_transpose(const NVTETensor input, - const NVTETensor gated_act_input, - NVTETensor cast_output, - NVTETensor transposed_output, +void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, + NVTETensor cast_output, NVTETensor transposed_output, cudaStream_t stream) { - NVTE_API_CALL(nvte_dqgeglu_cast_transpose); - using namespace transformer_engine; - - constexpr auto dActivation = &dqgelu; - constexpr auto Activation = &qgelu; - - dgated_act_cast_transpose( - *reinterpret_cast(input), - *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - stream); + NVTE_API_CALL(nvte_dqgeglu_cast_transpose); + using namespace transformer_engine; 
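// --- Illustrative sketch (not part of this patch) ---------------------------------------------
// The gated wrappers above (nvte_dgeglu/dswiglu/dreglu/dsreglu/dqgeglu_cast_transpose) differ only
// in which activation/derivative pair they plug into dgated_act_cast_transpose. Per element, with
// grad the incoming gradient, a the activation half and g the gate half of the gated input, the
// kernel's OP1/OP2 usage computes d_a = dphi(a) * grad * g and d_g = grad * phi(a). Below is a
// minimal host-side sketch of that math for the SwiGLU case (phi = SiLU); the silu/dsilu helpers
// are local stand-ins, not the Transformer Engine device functions.
#include <cmath>
#include <cstdio>

static float silu(float x) { return x / (1.f + std::exp(-x)); }

static float dsilu(float x) {
  const float s = 1.f / (1.f + std::exp(-x));  // sigmoid(x)
  return s * (1.f + x * (1.f - s));            // d/dx [x * sigmoid(x)]
}

int main() {
  const float grad = 0.75f;  // upstream gradient for this output element
  const float a = 1.25f;     // activation half of the gated input
  const float g = -0.5f;     // gate half of the gated input

  const float d_a = dsilu(a) * grad * g;  // corresponds to after_dact  = OP1(act) * in * gate
  const float d_g = grad * silu(a);       // corresponds to after_dgate = in * OP2(act)

  printf("d_a=%f d_g=%f\n", d_a, d_g);
  return 0;
}
// ------------------------------------------------------------------------------------------------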
+ + constexpr auto dActivation = &dqgelu; + constexpr auto Activation = &qgelu; + + dgated_act_cast_transpose( + *reinterpret_cast(input), *reinterpret_cast(gated_act_input), + reinterpret_cast(cast_output), reinterpret_cast(transposed_output), + stream); } diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index fa370240a002680c6aae8392870fa9d8bd1c254e..8e6e90a7bf619810116db4b9dd2bd5b4eeb30633 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -4,13 +4,15 @@ * See LICENSE for license information. ************************************************************************/ -#include #include -#include +#include + #include +#include #include -#include "../utils.cuh" + #include "../common.h" +#include "../utils.cuh" namespace transformer_engine { @@ -40,21 +42,14 @@ struct MultiCastTransposeArgs { int row_length_list[kMaxTensorsPerKernel]; // Prefix sum (with leading zero) of CUDA blocks needed for each // tensor - int block_range[kMaxTensorsPerKernel+1]; + int block_range[kMaxTensorsPerKernel + 1]; // Number of tensors being processed by kernel int num_tensors; }; -template < - int nvec_in, - int nvec_out, - bool aligned, - typename CType, - typename IType, - typename OType> -__global__ void -__launch_bounds__(threads_per_block) -multi_cast_transpose_kernel(MultiCastTransposeArgs args) { +template +__global__ void __launch_bounds__(threads_per_block) + multi_cast_transpose_kernel(MultiCastTransposeArgs args) { using IVec = Vec; using OVecC = Vec; using OVecT = Vec; @@ -79,7 +74,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { // Find tensor corresponding to block int tensor_id = 0; - while (args.block_range[tensor_id+1] <= bid) { + while (args.block_range[tensor_id + 1] <= bid) { ++tensor_id; } const IType* input = reinterpret_cast(args.input_list[tensor_id]); @@ -104,11 +99,11 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { // type, and transposes in registers. 
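// --- Illustrative sketch (not part of this patch) ---------------------------------------------
// multi_cast_transpose packs up to kMaxTensorsPerKernel tensors into a single launch. The host
// records, in block_range, a prefix sum (with a leading zero) of how many CUDA blocks each tensor
// needs, and each block recovers "its" tensor inside the kernel with the same linear scan over
// block_range shown in this file. The standalone example below mirrors that bookkeeping on the
// host; the tensor/tile counts are made up for illustration.
#include <cstdio>

constexpr int kMaxTensorsPerKernel = 64;

int tensor_for_block(const int* block_range, int num_tensors, int bid) {
  int tensor_id = 0;
  while (tensor_id + 1 <= num_tensors && block_range[tensor_id + 1] <= bid) ++tensor_id;
  return tensor_id;
}

int main() {
  // Three tensors needing 4, 1 and 6 blocks (tiles) respectively.
  const int tiles_per_tensor[3] = {4, 1, 6};
  int block_range[kMaxTensorsPerKernel + 1] = {0};
  int num_tensors = 0;
  for (int t = 0; t < 3; ++t) {
    block_range[num_tensors + 1] = block_range[num_tensors] + tiles_per_tensor[t];
    ++num_tensors;
  }
  const int n_blocks = block_range[num_tensors];  // 11 blocks for the whole launch
  for (int bid = 0; bid < n_blocks; ++bid) {
    // blocks 0-3 map to tensor 0, block 4 to tensor 1, blocks 5-10 to tensor 2
    printf("block %2d -> tensor %d\n", bid, tensor_for_block(block_range, num_tensors, bid));
  }
  return 0;
}
// ------------------------------------------------------------------------------------------------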
OVecT local_output_t[nvec_in][n_iterations]; CType local_amax = 0; - #pragma unroll +#pragma unroll for (int iter = 0; iter < n_iterations; ++iter) { const int i1 = tidy + iter * bdimy; const int j1 = tidx; - #pragma unroll +#pragma unroll for (int i2 = 0; i2 < nvec_out; ++i2) { const int row = tile_row + i1 * nvec_out + i2; const int col = tile_col + j1 * nvec_in; @@ -119,7 +114,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { } else { local_input.clear(); if (row < num_rows) { - #pragma unroll +#pragma unroll for (int j2 = 0; j2 < nvec_in; ++j2) { if (col + j2 < row_length) { local_input.data.elt[j2] = input[row * row_length + col + j2]; @@ -127,7 +122,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { } } } - #pragma unroll +#pragma unroll for (int j2 = 0; j2 < nvec_in; ++j2) { const CType x = CType(local_input.data.elt[j2]); const OType y = OType(scale * x); @@ -140,7 +135,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { local_output_c.store_to(&output_c[row * row_length + col]); } else { if (row < num_rows) { - #pragma unroll +#pragma unroll for (int j2 = 0; j2 < nvec_in; ++j2) { if (col + j2 < row_length) { output_c[row * row_length + col + j2] = local_output_c.data.elt[j2]; @@ -152,17 +147,17 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { } // Copy transposed output from registers to global memory - __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1]; - #pragma unroll + __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1]; +#pragma unroll for (int j2 = 0; j2 < nvec_in; ++j2) { - #pragma unroll +#pragma unroll for (int iter = 0; iter < n_iterations; ++iter) { const int i1 = tidy + iter * bdimy; const int j1 = tidx; shared_output_t[j1][i1] = local_output_t[j2][iter]; } __syncthreads(); - #pragma unroll +#pragma unroll for (int iter = 0; iter < n_iterations; ++iter) { const int i1 = tidx; const int j1 = tidy + iter * bdimy; @@ -172,7 +167,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]); } else { if (col < row_length) { - #pragma unroll +#pragma unroll for (int i2 = 0; i2 < nvec_out; ++i2) { if (row + i2 < num_rows) { output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2]; @@ -196,8 +191,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { void multi_cast_transpose(const std::vector input_list, std::vector cast_output_list, - std::vector transposed_output_list, - cudaStream_t stream) { + std::vector transposed_output_list, cudaStream_t stream) { // Check that number of tensors is valid NVTE_CHECK(cast_output_list.size() == input_list.size(), "Number of input and C output tensors must match"); @@ -218,15 +212,11 @@ void multi_cast_transpose(const std::vector input_list, CheckInputTensor(cast_output, "multi_cast_output_" + std::to_string(tensor_id)); CheckInputTensor(transposed_output, "multi_transpose_output_" + std::to_string(tensor_id)); - NVTE_CHECK(input.data.dtype == itype, - "Input tensor types do not match."); - NVTE_CHECK(cast_output.data.dtype == otype, - "C output tensor types do not match."); - NVTE_CHECK(transposed_output.data.dtype == otype, - "T output tensor types do not match."); + NVTE_CHECK(input.data.dtype == itype, "Input tensor types do not match."); + NVTE_CHECK(cast_output.data.dtype == otype, "C output tensor types do not match."); + NVTE_CHECK(transposed_output.data.dtype == otype, "T output tensor types do not match."); - NVTE_CHECK(input.data.shape.size() 
== 2, - "Input tensor must have 2 dimensions."); + NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions."); NVTE_CHECK(cast_output.data.shape == input.data.shape, "C output tensor shape does not match input tensor."); NVTE_CHECK(transposed_output.data.shape.size() == 2, @@ -251,27 +241,28 @@ void multi_cast_transpose(const std::vector input_list, for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { // Launch kernel if argument struct is full if (kernel_args_aligned.num_tensors == kMaxTensorsPerKernel) { - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType, - constexpr int nvec_in = desired_load_size / sizeof(InputType); - constexpr int nvec_out = desired_store_size / sizeof(OutputType); - const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors]; - multi_cast_transpose_kernel - <<>>(kernel_args_aligned); - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + itype, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType); + constexpr int nvec_out = desired_store_size / sizeof(OutputType); + const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors]; + multi_cast_transpose_kernel + <<>>(kernel_args_aligned);); // NOLINT(*) + ); // NOLINT(*) kernel_args_aligned.num_tensors = 0; } if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) { - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType, - constexpr int nvec_in = desired_load_size / sizeof(InputType); - constexpr int nvec_out = desired_store_size / sizeof(OutputType); - const int n_blocks = kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors]; - multi_cast_transpose_kernel - <<>>(kernel_args_unaligned); - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + itype, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType); + constexpr int nvec_out = desired_store_size / sizeof(OutputType); + const int n_blocks = + kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors]; + multi_cast_transpose_kernel + <<>>(kernel_args_unaligned);); // NOLINT(*) + ); // NOLINT(*) kernel_args_unaligned.num_tensors = 0; } @@ -283,8 +274,8 @@ void multi_cast_transpose(const std::vector input_list, const int num_tiles = num_tiles_m * num_tiles_n; // Figure out whether to use aligned or unaligned kernel - const bool aligned = ((num_tiles_m * tile_dim_m == num_rows) - && (num_tiles_n * tile_dim_n == row_length)); + const bool aligned = + ((num_tiles_m * tile_dim_m == num_rows) && (num_tiles_n * tile_dim_n == row_length)); auto& kernel_args = aligned ? 
kernel_args_aligned : kernel_args_unaligned; // Add tensor to kernel argument struct @@ -296,53 +287,48 @@ void multi_cast_transpose(const std::vector input_list, kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr; kernel_args.num_rows_list[pos] = num_rows; kernel_args.row_length_list[pos] = row_length; - kernel_args.block_range[pos+1] = kernel_args.block_range[pos] + num_tiles; + kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; kernel_args.num_tensors++; } // Launch kernel if (kernel_args_aligned.num_tensors > 0) { - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType, - constexpr int nvec_in = desired_load_size / sizeof(InputType); - constexpr int nvec_out = desired_store_size / sizeof(OutputType); - const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors]; - multi_cast_transpose_kernel - <<>>(kernel_args_aligned); - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + itype, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType); + constexpr int nvec_out = desired_store_size / sizeof(OutputType); + const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors]; + multi_cast_transpose_kernel + <<>>(kernel_args_aligned);); // NOLINT(*) + ); // NOLINT(*) } if (kernel_args_unaligned.num_tensors > 0) { - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType, - constexpr int nvec_in = desired_load_size / sizeof(InputType); - constexpr int nvec_out = desired_store_size / sizeof(OutputType); - const int n_blocks = kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors]; - multi_cast_transpose_kernel - <<>>(kernel_args_unaligned); - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + itype, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType); + constexpr int nvec_out = desired_store_size / sizeof(OutputType); + const int n_blocks = + kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors]; + multi_cast_transpose_kernel + <<>>(kernel_args_unaligned);); // NOLINT(*) + ); // NOLINT(*) } } } // namespace transformer_engine -void nvte_multi_cast_transpose(size_t num_tensors, - const NVTETensor* input_list, - NVTETensor* cast_output_list, - NVTETensor* transposed_output_list, +void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, + NVTETensor* cast_output_list, NVTETensor* transposed_output_list, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_cast_transpose); using namespace transformer_engine; - std::vector input_list_, - cast_output_list_, transposed_output_list_; + std::vector input_list_, cast_output_list_, transposed_output_list_; for (size_t i = 0; i < num_tensors; ++i) { input_list_.push_back(reinterpret_cast(const_cast(input_list[i]))); cast_output_list_.push_back(reinterpret_cast(cast_output_list[i])); transposed_output_list_.push_back(reinterpret_cast(transposed_output_list[i])); } - multi_cast_transpose(input_list_, - cast_output_list_, - transposed_output_list_, - stream); + multi_cast_transpose(input_list_, cast_output_list_, transposed_output_list_, stream); } diff --git a/transformer_engine/common/transpose/rtc/cast_transpose.cu b/transformer_engine/common/transpose/rtc/cast_transpose.cu index 
d5035817182ad6858468f2cb08023c15365b5a0c..6ea83261473c4faee0973cf5aeb6e92d48b47ec8 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose.cu @@ -21,16 +21,11 @@ constexpr size_t block_size = __BLOCK_SIZE__; } // namespace -__global__ void -__launch_bounds__(block_size) -cast_transpose_optimized_kernel(const IType * __restrict__ const input, - const CType * __restrict__ const noop, - OType * __restrict__ const output_c, - OType * __restrict__ const output_t, - const CType * __restrict__ const scale_ptr, - CType * __restrict__ const amax_ptr, - const size_t row_length, - const size_t num_rows) { +__global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel( + const IType* __restrict__ const input, const CType* __restrict__ const noop, + OType* __restrict__ const output_c, OType* __restrict__ const output_t, + const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr, + const size_t row_length, const size_t num_rows) { if (noop != nullptr && noop[0] == 1.0f) return; // Vectorized load/store sizes @@ -73,18 +68,18 @@ cast_transpose_optimized_kernel(const IType * __restrict__ const input, // Note: Each thread loads num_iterations subtiles, computes amax, // casts type, and transposes in registers. OVecT local_output_t[nvec_in][num_iterations]; - #pragma unroll +#pragma unroll for (size_t iter = 0; iter < num_iterations; ++iter) { const size_t i1 = tidy + iter * bdimy; const size_t j1 = tidx; - #pragma unroll +#pragma unroll for (size_t i2 = 0; i2 < nvec_out; ++i2) { const size_t row = tile_row + i1 * nvec_out + i2; const size_t col = tile_col + j1 * nvec_in; IVec local_input; OVecC local_output_c; local_input.load_from(&input[row * row_length + col]); - #pragma unroll +#pragma unroll for (size_t j2 = 0; j2 < nvec_in; ++j2) { const CType in = static_cast(local_input.data.elt[j2]); const OType out = OType(in * scale); @@ -98,17 +93,17 @@ cast_transpose_optimized_kernel(const IType * __restrict__ const input, } // Copy from registers to shared memory to global memory - __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1]; - #pragma unroll + __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1]; +#pragma unroll for (size_t j2 = 0; j2 < nvec_in; ++j2) { - #pragma unroll +#pragma unroll for (size_t iter = 0; iter < num_iterations; ++iter) { const size_t i1 = tidy + iter * bdimy; const size_t j1 = tidx; shared_output_t[j1][i1] = local_output_t[j2][iter]; } __syncthreads(); - #pragma unroll +#pragma unroll for (size_t iter = 0; iter < num_iterations; ++iter) { const size_t i1 = tidx; const size_t j1 = tidy + iter * bdimy; diff --git a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu index a5c1c661c89ae4a2ce794d47eb5d6eebd22346d3..c005be98efd767c49523cb25aacc20b69eb952e0 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu @@ -4,25 +4,25 @@ * See LICENSE for license information. 
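  // Note on the pattern shared by the transpose kernels in these files: each
  // warp first casts and transposes an nvec_in x nvec_out sub-tile entirely
  // in registers, and only then exchanges the per-thread results through a
  // shared-memory tile declared as [THREADS_PER_WARP][THREADS_PER_WARP + 1].
  // The extra "+ 1" column of padding shifts successive rows onto different
  // shared-memory banks, so the strided accesses after __syncthreads() do
  // not all land on the same bank, while the global-memory loads and stores
  // on either side of the exchange remain coalesced.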
************************************************************************/ -#include "utils.cuh" #include "util/math.h" +#include "utils.cuh" using namespace transformer_engine; namespace { // Parameters -using CType = float; -using IType = __ITYPE__; +using CType = float; +using IType = __ITYPE__; using IType2 = __ITYPE2__; -using OType = __OTYPE__; -constexpr size_t LOAD_SIZE = __LOAD_SIZE__; -constexpr size_t STORE_SIZE = __STORE_SIZE__; -constexpr size_t WARPS_PER_TILE = __WARPS_PER_TILE__; -constexpr size_t BLOCK_SIZE = __BLOCK_SIZE__; -constexpr bool IS_DBIAS = __IS_DBIAS__; -constexpr bool IS_DACT = __IS_DACT__; -constexpr size_t DACT_TYPE = __DACTIVATION_TYPE__; +using OType = __OTYPE__; +constexpr size_t LOAD_SIZE = __LOAD_SIZE__; +constexpr size_t STORE_SIZE = __STORE_SIZE__; +constexpr size_t WARPS_PER_TILE = __WARPS_PER_TILE__; +constexpr size_t BLOCK_SIZE = __BLOCK_SIZE__; +constexpr bool IS_DBIAS = __IS_DBIAS__; +constexpr bool IS_DACT = __IS_DACT__; +constexpr size_t DACT_TYPE = __DACTIVATION_TYPE__; constexpr size_t NVEC_IN = LOAD_SIZE / sizeof(IType); constexpr size_t NVEC_OUT = STORE_SIZE / sizeof(OType); @@ -32,218 +32,209 @@ using IVec2 = Vec; using OVec = Vec; using Param = CTDBiasDActParam; -using OP = CType (*)(const CType, const Empty&); +using OP = CType (*)(const CType, const Empty &); constexpr OP Activation[] = { - nullptr, // 0 - &dsigmoid, // 1 - &dgelu, // 2 - &dqgelu, // 3 - &dsilu, // 4 - &drelu, // 5 - &dsrelu // 6 + nullptr, // 0 + &dsigmoid, // 1 + &dgelu, // 2 + &dqgelu, // 3 + &dsilu, // 4 + &drelu, // 5 + &dsrelu // 6 }; } // namespace -inline __device__ void -cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_OUT], - OVec (&out_trans)[NVEC_IN], - CVec &out_dbias, // NOLINT(*) - typename OVec::type *output_cast_tile, - const size_t current_place, - const size_t stride, - const CType scale, - CType &amax, // NOLINT(*) - const int dbias_shfl_src_lane) { - using OVecC = Vec; - - CVec step_dbias; - if constexpr (IS_DBIAS) { - step_dbias.clear(); - } - - #pragma unroll - for (unsigned int i = 0; i < NVEC_OUT; ++i) { - OVecC out_cast; - #pragma unroll - for (unsigned int j = 0; j < NVEC_IN; ++j) { - const CType tmp = in[i].data.elt[j]; - if constexpr (IS_DBIAS) { - step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation - } - out_cast.data.elt[j] = static_cast(tmp * scale); - out_trans[j].data.elt[i] = static_cast(tmp * scale); // thread tile transpose - - __builtin_assume(amax >= 0); - amax = fmaxf(fabsf(tmp), amax); - } - out_cast.store_to(output_cast_tile, current_place + stride * i); +inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_OUT], + OVec (&out_trans)[NVEC_IN], + CVec &out_dbias, // NOLINT(*) + typename OVec::type *output_cast_tile, + const size_t current_place, + const size_t stride, const CType scale, + CType &amax, // NOLINT(*) + const int dbias_shfl_src_lane) { + using OVecC = Vec; + + CVec step_dbias; + if constexpr (IS_DBIAS) { + step_dbias.clear(); + } + +#pragma unroll + for (unsigned int i = 0; i < NVEC_OUT; ++i) { + OVecC out_cast; +#pragma unroll + for (unsigned int j = 0; j < NVEC_IN; ++j) { + const CType tmp = in[i].data.elt[j]; + if constexpr (IS_DBIAS) { + step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation + } + out_cast.data.elt[j] = static_cast(tmp * scale); + out_trans[j].data.elt[i] = static_cast(tmp * scale); // thread tile transpose + + __builtin_assume(amax >= 0); + amax = fmaxf(fabsf(tmp), amax); } - - if constexpr (IS_DBIAS) { - #pragma unroll - for 
(unsigned int j = 0; j < NVEC_IN; ++j) { - CType elt = step_dbias.data.elt[j]; - elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp - out_dbias.data.elt[j] += elt; - } + out_cast.store_to(output_cast_tile, current_place + stride * i); + } + + if constexpr (IS_DBIAS) { +#pragma unroll + for (unsigned int j = 0; j < NVEC_IN; ++j) { + CType elt = step_dbias.data.elt[j]; + elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp + out_dbias.data.elt[j] += elt; } + } } -__global__ void -__launch_bounds__(BLOCK_SIZE) -cast_transpose_fusion_kernel_optimized(const Param param, - const size_t row_length, - const size_t num_rows, - const size_t num_tiles) { - extern __shared__ char scratch[]; - - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; - const size_t num_tiles_x = row_length / (NVEC_IN * THREADS_PER_WARP); - const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * WARPS_PER_TILE) - + warp_id / WARPS_PER_TILE; - if (tile_id >= num_tiles) { - return; +__global__ void __launch_bounds__(BLOCK_SIZE) + cast_transpose_fusion_kernel_optimized(const Param param, const size_t row_length, + const size_t num_rows, const size_t num_tiles) { + extern __shared__ char scratch[]; + + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; + const size_t num_tiles_x = row_length / (NVEC_IN * THREADS_PER_WARP); + const size_t tile_id = + blockIdx.x * blockDim.x / (THREADS_PER_WARP * WARPS_PER_TILE) + warp_id / WARPS_PER_TILE; + if (tile_id >= num_tiles) { + return; + } + + const size_t tile_id_x = tile_id % num_tiles_x; + const size_t tile_id_y = tile_id / num_tiles_x; + + const size_t tile_offset = + (tile_id_x * NVEC_IN + tile_id_y * row_length * NVEC_OUT) * THREADS_PER_WARP; + const size_t tile_offset_transp = + (tile_id_y * NVEC_OUT + tile_id_x * num_rows * NVEC_IN) * THREADS_PER_WARP; + + const IType *const my_input_tile = param.input + tile_offset; + const IType2 *const my_act_input_tile = param.act_input + tile_offset; + OType *const my_output_c_tile = param.output_c + tile_offset; + OType *const my_output_t_tile = param.output_t + tile_offset_transp; + CType *const my_partial_dbias_tile = + param.workspace + (tile_id_x * (NVEC_IN * THREADS_PER_WARP) + tile_id_y * row_length); + + OVec *const my_scratch = + reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / WARPS_PER_TILE * THREADS_PER_WARP) * (THREADS_PER_WARP + 1); + + CVec *const my_dbias_scratch = reinterpret_cast(scratch); + + IVec in[2][NVEC_OUT]; + IVec2 act_in[2][NVEC_OUT]; + + const unsigned int warp_id_in_tile = warp_id % WARPS_PER_TILE; + constexpr unsigned int n_iterations = THREADS_PER_WARP / WARPS_PER_TILE; + OVec out_space[n_iterations][NVEC_IN]; + + const size_t stride = row_length / NVEC_IN; + const size_t output_stride = num_rows / NVEC_OUT; + size_t current_stride = warp_id_in_tile * n_iterations * NVEC_OUT * stride; + size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * NVEC_OUT; + unsigned int my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + + CType amax = 0.0f; + const CType scale = param.scale_ptr != nullptr ? 
*param.scale_ptr : 1; + + CVec partial_dbias; + if constexpr (IS_DBIAS) { + partial_dbias.clear(); + } + +#pragma unroll + for (unsigned int i = 0; i < NVEC_OUT; ++i) { + in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); + if constexpr (IS_DACT) { + act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i); } - - const size_t tile_id_x = tile_id % num_tiles_x; - const size_t tile_id_y = tile_id / num_tiles_x; - - const size_t tile_offset = (tile_id_x * NVEC_IN + tile_id_y * row_length * NVEC_OUT) - * THREADS_PER_WARP; - const size_t tile_offset_transp = (tile_id_y * NVEC_OUT + tile_id_x * num_rows * NVEC_IN) - * THREADS_PER_WARP; - - const IType * const my_input_tile = param.input + tile_offset; - const IType2 * const my_act_input_tile = param.act_input + tile_offset; - OType * const my_output_c_tile = param.output_c + tile_offset; - OType * const my_output_t_tile = param.output_t + tile_offset_transp; - CType * const my_partial_dbias_tile = param.workspace - + (tile_id_x * (NVEC_IN * THREADS_PER_WARP) - + tile_id_y * row_length); - - OVec * const my_scratch = reinterpret_cast(scratch) - + (my_id_in_warp + warp_id / WARPS_PER_TILE * THREADS_PER_WARP) - * (THREADS_PER_WARP + 1); - - CVec * const my_dbias_scratch = reinterpret_cast(scratch); - - IVec in[2][NVEC_OUT]; - IVec2 act_in[2][NVEC_OUT]; - - const unsigned int warp_id_in_tile = warp_id % WARPS_PER_TILE; - constexpr unsigned int n_iterations = THREADS_PER_WARP / WARPS_PER_TILE; - OVec out_space[n_iterations][NVEC_IN]; - - const size_t stride = row_length / NVEC_IN; - const size_t output_stride = num_rows / NVEC_OUT; - size_t current_stride = warp_id_in_tile * n_iterations * NVEC_OUT * stride; - size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * NVEC_OUT; - unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) - % THREADS_PER_WARP; - - CType amax = 0.0f; - const CType scale = param.scale_ptr != nullptr ? 
*param.scale_ptr : 1; - - CVec partial_dbias; - if constexpr (IS_DBIAS) { - partial_dbias.clear(); - } - - #pragma unroll - for (unsigned int i = 0; i < NVEC_OUT; ++i) { - in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); + } +#pragma unroll + for (unsigned int i = 0; i < n_iterations; ++i) { + const size_t current_place = current_stride + my_place; + const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + const unsigned int current_in = (i + 1) % 2; + if (i < n_iterations - 1) { +#pragma unroll + for (unsigned int j = 0; j < NVEC_OUT; ++j) { + const size_t ld_offset = current_stride + my_place_in + stride * (NVEC_OUT + j); + in[current_in][j].load_from(my_input_tile, ld_offset); if constexpr (IS_DACT) { - act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i); + act_in[current_in][j].load_from(my_act_input_tile, ld_offset); } + } } - #pragma unroll - for (unsigned int i = 0; i < n_iterations; ++i) { - const size_t current_place = current_stride + my_place; - const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - const unsigned int current_in = (i + 1) % 2; - if (i < n_iterations - 1) { - #pragma unroll - for (unsigned int j = 0; j < NVEC_OUT; ++j) { - const size_t ld_offset = current_stride + my_place_in + stride * (NVEC_OUT + j); - in[current_in][j].load_from(my_input_tile, ld_offset); - if constexpr (IS_DACT) { - act_in[current_in][j].load_from(my_act_input_tile, ld_offset); - } - } - } - CVec in_cast_fp32[NVEC_OUT]; // NOLINT(*) - #pragma unroll - for (unsigned int j = 0; j < NVEC_OUT; ++j) { - #pragma unroll - for (unsigned int k = 0; k < NVEC_IN; ++k) { - if constexpr (IS_DACT) { - in_cast_fp32[j].data.elt[k] = - static_cast(in[current_in ^ 1][j].data.elt[k]) - * Activation[DACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {}); - } else { - in_cast_fp32[j].data.elt[k] = - static_cast(in[current_in ^ 1][j].data.elt[k]); - } - } + CVec in_cast_fp32[NVEC_OUT]; // NOLINT(*) +#pragma unroll + for (unsigned int j = 0; j < NVEC_OUT; ++j) { +#pragma unroll + for (unsigned int k = 0; k < NVEC_IN; ++k) { + if constexpr (IS_DACT) { + in_cast_fp32[j].data.elt[k] = + static_cast(in[current_in ^ 1][j].data.elt[k]) * + Activation[DACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {}); + } else { + in_cast_fp32[j].data.elt[k] = static_cast(in[current_in ^ 1][j].data.elt[k]); } + } + } - const int dbias_shfl_src_lane = (my_id_in_warp + i + warp_id_in_tile * n_iterations) - % THREADS_PER_WARP; + const int dbias_shfl_src_lane = + (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP; - cast_and_transpose_regs_optimized(in_cast_fp32, out_space[i], partial_dbias, - my_output_c_tile, current_place, - stride, scale, amax, dbias_shfl_src_lane); + cast_and_transpose_regs_optimized(in_cast_fp32, out_space[i], partial_dbias, my_output_c_tile, + current_place, stride, scale, amax, dbias_shfl_src_lane); - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += NVEC_OUT * stride; - current_row += NVEC_OUT; - } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += NVEC_OUT * stride; + current_row += NVEC_OUT; + } - #pragma unroll - for (unsigned int i = 0; i < NVEC_IN; ++i) { - #pragma unroll - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; - } - __syncthreads(); - my_place = (my_id_in_warp + 
THREADS_PER_WARP - warp_id_in_tile * n_iterations) - % THREADS_PER_WARP; - current_stride = i * output_stride - + warp_id_in_tile * n_iterations * output_stride * NVEC_IN; - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, - current_stride + my_place); - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += output_stride * NVEC_IN; - } - __syncthreads(); +#pragma unroll + for (unsigned int i = 0; i < NVEC_IN; ++i) { +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space[j][i]; } - - if constexpr (IS_DBIAS) { - my_dbias_scratch[threadIdx.x] = partial_dbias; - __syncthreads(); - if (warp_id_in_tile == 0) { - #pragma unroll - for (unsigned int i = 1; i < WARPS_PER_TILE; ++i) { - CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP]; - #pragma unroll - for (unsigned int j = 0; j < NVEC_IN; ++j) { - partial_dbias.data.elt[j] += tmp.data.elt[j]; - } - } - partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp); + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * NVEC_IN; + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, + current_stride + my_place); + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * NVEC_IN; + } + __syncthreads(); + } + + if constexpr (IS_DBIAS) { + my_dbias_scratch[threadIdx.x] = partial_dbias; + __syncthreads(); + if (warp_id_in_tile == 0) { +#pragma unroll + for (unsigned int i = 1; i < WARPS_PER_TILE; ++i) { + CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP]; +#pragma unroll + for (unsigned int j = 0; j < NVEC_IN; ++j) { + partial_dbias.data.elt[j] += tmp.data.elt[j]; } + } + partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp); } + } - // warp tile amax reduce - const CType max_block = reduce_max(amax, warp_id); + // warp tile amax reduce + const CType max_block = reduce_max(amax, warp_id); - if (threadIdx.x == 0) { - if (param.amax != nullptr) { - atomicMaxFloat(param.amax, max_block); - } + if (threadIdx.x == 0) { + if (param.amax != nullptr) { + atomicMaxFloat(param.amax, max_block); } + } } diff --git a/transformer_engine/common/transpose/rtc/transpose.cu b/transformer_engine/common/transpose/rtc/transpose.cu index f21014866ba11dca62b859b2b1640936db005d1d..09758698f65cf5f8fe6fbecc40bb1fdcb454a79f 100644 --- a/transformer_engine/common/transpose/rtc/transpose.cu +++ b/transformer_engine/common/transpose/rtc/transpose.cu @@ -19,13 +19,10 @@ constexpr size_t block_size = __BLOCK_SIZE__; } // namespace -__global__ void -__launch_bounds__(block_size) -transpose_optimized_kernel(const Type * __restrict__ const input, - const float * const noop, - Type * __restrict__ const output, - const size_t row_length, - const size_t num_rows) { +__global__ void __launch_bounds__(block_size) + transpose_optimized_kernel(const Type* __restrict__ const input, const float* const noop, + Type* __restrict__ const output, const size_t row_length, + const size_t num_rows) { if (noop != nullptr && noop[0] == 1.0f) return; // Vectorized load/store sizes @@ -63,17 +60,17 @@ transpose_optimized_kernel(const Type * __restrict__ const input, // 
Note: Each thread loads num_iterations subtiles and transposes in // registers. OVec local_output[nvec_in][num_iterations]; - #pragma unroll +#pragma unroll for (size_t iter = 0; iter < num_iterations; ++iter) { const size_t i1 = tidy + iter * bdimy; const size_t j1 = tidx; - #pragma unroll +#pragma unroll for (size_t i2 = 0; i2 < nvec_out; ++i2) { const size_t row = tile_row + i1 * nvec_out + i2; const size_t col = tile_col + j1 * nvec_in; IVec local_input; local_input.load_from(&input[row * row_length + col]); - #pragma unroll +#pragma unroll for (size_t j2 = 0; j2 < nvec_in; ++j2) { local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2]; } @@ -81,17 +78,17 @@ transpose_optimized_kernel(const Type * __restrict__ const input, } // Copy from registers to shared memory to global memory - __shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP+1]; - #pragma unroll + __shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP + 1]; +#pragma unroll for (size_t j2 = 0; j2 < nvec_in; ++j2) { - #pragma unroll +#pragma unroll for (size_t iter = 0; iter < num_iterations; ++iter) { const size_t i1 = tidy + iter * bdimy; const size_t j1 = tidx; shared_output[j1][i1] = local_output[j2][iter]; } __syncthreads(); - #pragma unroll +#pragma unroll for (size_t iter = 0; iter < num_iterations; ++iter) { const size_t i1 = tidx; const size_t j1 = tidy + iter * bdimy; diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index c0a1a7fbcf96d5cda5e2ad719c3dd23f213aba6a..115acafe22278e704130255828111459c4c3f9e9 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -4,13 +4,12 @@ * See LICENSE for license information. ************************************************************************/ +#include #include #include #include -#include - #include "../common.h" #include "../util/rtc.h" #include "../util/string.h" @@ -46,24 +45,18 @@ struct KernelConfig { /* Elements per L1 cache store */ size_t elements_per_store = 0; - KernelConfig(size_t row_length, - size_t num_rows, - size_t type_size, - size_t load_size_, + KernelConfig(size_t row_length, size_t num_rows, size_t type_size, size_t load_size_, size_t store_size_) - : load_size{load_size_} - , store_size{store_size_} { + : load_size{load_size_}, store_size{store_size_} { // Check that tiles are correctly aligned constexpr size_t cache_line_size = 128; - if (load_size % type_size != 0 - || store_size % type_size != 0 - || cache_line_size % type_size != 0) { + if (load_size % type_size != 0 || store_size % type_size != 0 || + cache_line_size % type_size != 0) { return; } const size_t row_tile_elements = load_size * THREADS_PER_WARP / type_size; const size_t col_tile_elements = store_size * THREADS_PER_WARP / type_size; - valid = (row_length % row_tile_elements == 0 - && num_rows % col_tile_elements == 0); + valid = (row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0); if (!valid) { return; } @@ -75,10 +68,8 @@ struct KernelConfig { constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm), static_cast(cuda::sm_count())); - elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size) - / type_size); - elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size) - / type_size); + elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size) / type_size); + 
elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size) / type_size); } /* Compare by estimated cost */ @@ -93,8 +84,8 @@ struct KernelConfig { const auto &s2 = other.elements_per_store; const auto &p2 = other.active_sm_count; const auto scale = l1 * s1 * p1 * l2 * s2 * p2; - const auto cost1 = (scale/l1 + scale/s1) / p1; - const auto cost2 = (scale/l2 + scale/s2) / p2; + const auto cost1 = (scale / l1 + scale / s1) / p1; + const auto cost2 = (scale / l2 + scale / s2) / p2; return cost1 < cost2; } else { return this->valid && !other.valid; @@ -103,13 +94,10 @@ struct KernelConfig { }; template -__global__ void -__launch_bounds__(block_size) -transpose_general_kernel(const Type * __restrict__ const input, - const fp32 * const noop, - Type * __restrict__ const output, - const size_t row_length, - const size_t num_rows) { +__global__ void __launch_bounds__(block_size) + transpose_general_kernel(const Type *__restrict__ const input, const fp32 *const noop, + Type *__restrict__ const output, const size_t row_length, + const size_t num_rows) { if (noop != nullptr && noop[0] == 1.0f) return; // Vectorized load/store sizes @@ -147,25 +135,25 @@ transpose_general_kernel(const Type * __restrict__ const input, // Note: Each thread loads num_iterations subtiles and transposes in // registers. OVec local_output[nvec_in][num_iterations]; - #pragma unroll +#pragma unroll for (size_t iter = 0; iter < num_iterations; ++iter) { const size_t i1 = tidy + iter * bdimy; const size_t j1 = tidx; - #pragma unroll +#pragma unroll for (size_t i2 = 0; i2 < nvec_out; ++i2) { const size_t row = tile_row + i1 * nvec_out + i2; const size_t col = tile_col + j1 * nvec_in; IVec local_input; local_input.clear(); if (row < num_rows) { - #pragma unroll +#pragma unroll for (size_t j2 = 0; j2 < nvec_in; ++j2) { if (col + j2 < row_length) { local_input.data.elt[j2] = input[row * row_length + col + j2]; } } } - #pragma unroll +#pragma unroll for (size_t j2 = 0; j2 < nvec_in; ++j2) { local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2]; } @@ -173,24 +161,24 @@ transpose_general_kernel(const Type * __restrict__ const input, } // Copy transposed output from registers to global memory - __shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP+1]; - #pragma unroll + __shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP + 1]; +#pragma unroll for (size_t j2 = 0; j2 < nvec_in; ++j2) { - #pragma unroll +#pragma unroll for (size_t iter = 0; iter < num_iterations; ++iter) { const size_t i1 = tidy + iter * bdimy; const size_t j1 = tidx; shared_output[j1][i1] = local_output[j2][iter]; } __syncthreads(); - #pragma unroll +#pragma unroll for (size_t iter = 0; iter < num_iterations; ++iter) { const size_t i1 = tidx; const size_t j1 = tidy + iter * bdimy; const size_t row = tile_row + i1 * nvec_out; const size_t col = tile_col + j1 * nvec_in + j2; if (col < row_length) { - #pragma unroll +#pragma unroll for (size_t i2 = 0; i2 < nvec_out; ++i2) { if (row + i2 < num_rows) { output[col * num_rows + row + i2] = shared_output[j1][i1].data.elt[i2]; @@ -204,10 +192,7 @@ transpose_general_kernel(const Type * __restrict__ const input, } // namespace -void transpose(const Tensor &input, - const Tensor &noop, - Tensor *output_, - cudaStream_t stream) { +void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream) { Tensor &output = *output_; NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(output.data.shape.size() == 2, "Output must have 
2 dimensions."); @@ -219,121 +204,106 @@ void transpose(const Tensor &input, NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated."); - NVTE_CHECK(input.data.dtype == output.data.dtype, - "Input and output type must match."); + NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match."); // Number of elements in tensor - auto numel = [] (const Tensor &tensor) -> size_t { + auto numel = [](const Tensor &tensor) -> size_t { size_t acc = 1; - for (const auto& dim : tensor.data.shape) { + for (const auto &dim : tensor.data.shape) { acc *= dim; } return acc; }; if (noop.data.dptr != nullptr) { - NVTE_CHECK(numel(noop) == 1, - "Expected 1 element, ", - "but found ", numel(noop), "."); + NVTE_CHECK(numel(noop) == 1, "Expected 1 element, ", "but found ", numel(noop), "."); NVTE_CHECK(noop.data.dtype == DType::kFloat32); NVTE_CHECK(noop.data.dptr != nullptr); } - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.data.dtype, Type, - constexpr const char *type_name = TypeInfo::name; - constexpr size_t type_size = sizeof(Type); - - // Choose between runtime-compiled or statically-compiled kernel - const bool aligned = (row_length % THREADS_PER_WARP == 0 - && num_rows % THREADS_PER_WARP == 0); - if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel - // Pick kernel config - std::vector kernel_configs; - kernel_configs.reserve(16); - auto add_config = [&](size_t load_size, size_t store_size) { - kernel_configs.emplace_back(row_length, num_rows, type_size, - load_size, store_size); - }; - add_config(8, 8); - add_config(4, 8); add_config(8, 4); - add_config(4, 4); - add_config(2, 8); add_config(8, 2); - add_config(2, 4); add_config(4, 2); - add_config(2, 2); - add_config(1, 8); add_config(8, 1); - add_config(1, 4); add_config(4, 1); - add_config(1, 2); add_config(2, 1); - add_config(1, 1); - const auto &kernel_config = *std::min_element(kernel_configs.begin(), - kernel_configs.end()); - NVTE_CHECK(kernel_config.valid, "invalid kernel config"); - const size_t load_size = kernel_config.load_size; - const size_t store_size = kernel_config.store_size; - const size_t num_blocks = kernel_config.num_blocks; - - // Compile NVRTC kernel if needed and launch - auto& rtc_manager = rtc::KernelManager::instance(); - const std::string kernel_label = concat_strings("transpose" - ",type=", type_name, - ",load_size=", load_size, - ",store_size=", store_size); - if (!rtc_manager.is_compiled(kernel_label)) { - std::string code = string_code_transpose_rtc_transpose_cu; - code = regex_replace(code, "__TYPE__", type_name); - code = regex_replace(code, "__LOAD_SIZE__", load_size); - code = regex_replace(code, "__STORE_SIZE__", store_size); - code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); - code = regex_replace(code, "__BLOCK_SIZE__", block_size); - rtc_manager.compile(kernel_label, - "transpose_optimized_kernel", - code, - "transformer_engine/common/transpose/rtc/transpose.cu"); - } - rtc_manager.launch(kernel_label, - num_blocks, block_size, 0, stream, - static_cast(input.data.dptr), - static_cast(noop.data.dptr), - static_cast(output.data.dptr), - row_length, num_rows); - } else { // Statically-compiled general kernel - constexpr size_t load_size = 4; - constexpr size_t store_size = 4; - constexpr size_t row_tile_size = load_size / type_size * THREADS_PER_WARP; - constexpr size_t col_tile_size = store_size / type_size * THREADS_PER_WARP; - const int num_blocks = (DIVUP(row_length, 
row_tile_size) - * DIVUP(num_rows, col_tile_size)); - transpose_general_kernel<<>>( - static_cast(input.data.dptr), - static_cast(noop.data.dptr), - static_cast(output.data.dptr), - row_length, num_rows); - } - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + input.data.dtype, Type, constexpr const char *type_name = TypeInfo::name; + constexpr size_t type_size = sizeof(Type); + + // Choose between runtime-compiled or statically-compiled kernel + const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0); + if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel + // Pick kernel config + std::vector kernel_configs; + kernel_configs.reserve(16); + auto add_config = [&](size_t load_size, size_t store_size) { + kernel_configs.emplace_back(row_length, num_rows, type_size, load_size, store_size); + }; + add_config(8, 8); + add_config(4, 8); + add_config(8, 4); + add_config(4, 4); + add_config(2, 8); + add_config(8, 2); + add_config(2, 4); + add_config(4, 2); + add_config(2, 2); + add_config(1, 8); + add_config(8, 1); + add_config(1, 4); + add_config(4, 1); + add_config(1, 2); + add_config(2, 1); + add_config(1, 1); + const auto &kernel_config = *std::min_element(kernel_configs.begin(), kernel_configs.end()); + NVTE_CHECK(kernel_config.valid, "invalid kernel config"); + const size_t load_size = kernel_config.load_size; + const size_t store_size = kernel_config.store_size; + const size_t num_blocks = kernel_config.num_blocks; + + // Compile NVRTC kernel if needed and launch + auto &rtc_manager = rtc::KernelManager::instance(); + const std::string kernel_label = concat_strings( + "transpose" + ",type=", + type_name, ",load_size=", load_size, ",store_size=", store_size); + if (!rtc_manager.is_compiled(kernel_label)) { + std::string code = string_code_transpose_rtc_transpose_cu; + code = regex_replace(code, "__TYPE__", type_name); + code = regex_replace(code, "__LOAD_SIZE__", load_size); + code = regex_replace(code, "__STORE_SIZE__", store_size); + code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); + code = regex_replace(code, "__BLOCK_SIZE__", block_size); + rtc_manager.compile(kernel_label, "transpose_optimized_kernel", code, + "transformer_engine/common/transpose/rtc/transpose.cu"); + } + rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream, + static_cast(input.data.dptr), + static_cast(noop.data.dptr), + static_cast(output.data.dptr), row_length, num_rows); + } else { // Statically-compiled general kernel + constexpr size_t load_size = 4; + constexpr size_t store_size = 4; + constexpr size_t row_tile_size = load_size / type_size * THREADS_PER_WARP; + constexpr size_t col_tile_size = store_size / type_size * THREADS_PER_WARP; + const int num_blocks = (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size)); + transpose_general_kernel + <<>>(static_cast(input.data.dptr), + static_cast(noop.data.dptr), + static_cast(output.data.dptr), + row_length, num_rows); + }); // NOLINT(*) } } // namespace transformer_engine -void nvte_transpose(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_transpose); using namespace transformer_engine; auto noop = Tensor(); - transpose(*reinterpret_cast(input), - noop, - reinterpret_cast(output), + transpose(*reinterpret_cast(input), noop, reinterpret_cast(output), stream); } - -void nvte_transpose_with_noop(const NVTETensor input, - const NVTETensor 
noop, - NVTETensor output, +void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_transpose_with_noop); using namespace transformer_engine; - transpose(*reinterpret_cast(input), - *reinterpret_cast(noop), - reinterpret_cast(output), - stream); + transpose(*reinterpret_cast(input), *reinterpret_cast(noop), + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/transpose/transpose_fusion.cu b/transformer_engine/common/transpose/transpose_fusion.cu index e7be1804af25aac6d5da7baecdbfe741146799ed..c03237194064ffc9dc77fff2335ff4d6eedff67a 100644 --- a/transformer_engine/common/transpose/transpose_fusion.cu +++ b/transformer_engine/common/transpose/transpose_fusion.cu @@ -4,18 +4,19 @@ * See LICENSE for license information. ************************************************************************/ -#include #include +#include + #include #include #include -#include "../utils.cuh" + #include "../common.h" +#include "../utils.cuh" namespace transformer_engine { -template +template inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out], OVec (&out_trans)[nvec_in], CVec &out_dbias, // NOLINT(*) @@ -24,7 +25,8 @@ inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out], using T = typename OVec::type; using OVecC = Vec; - CVec step_dbias; step_dbias.clear(); + CVec step_dbias; + step_dbias.clear(); #pragma unroll for (unsigned int i = 0; i < nvec_out; ++i) { @@ -61,24 +63,21 @@ namespace { template struct TDBiasParam { - using InputType = IType; - using OutputType = OType; - using ComputeType = CType; - const IType *input; - OType *output_t; - const CType *scale_inv; - CType *workspace; + using InputType = IType; + using OutputType = OType; + using ComputeType = CType; + const IType *input; + OType *output_t; + const CType *scale_inv; + CType *workspace; }; } // namespace template -__global__ void -__launch_bounds__(cast_transpose_num_threads) -transpose_dbias_kernel(const Param param, - const size_t row_length, - const size_t num_rows, - const size_t num_tiles) { +__global__ void __launch_bounds__(cast_transpose_num_threads) + transpose_dbias_kernel(const Param param, const size_t row_length, const size_t num_rows, + const size_t num_tiles) { using IType = typename Param::InputType; using OType = typename Param::OutputType; using CType = typename Param::ComputeType; @@ -92,27 +91,24 @@ transpose_dbias_kernel(const Param param, const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP); // const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP); - const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + - warp_id / n_warps_per_tile; + const size_t tile_id = + blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile; if (tile_id >= num_tiles) return; const size_t tile_id_x = tile_id % num_tiles_x; const size_t tile_id_y = tile_id / num_tiles_x; - const IType * const my_input_tile = param.input + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out + - tile_id_x * num_rows * nvec_in) * - THREADS_PER_WARP; - CType * const my_partial_dbias_tile = param.workspace + - (tile_id_x * (nvec_in * THREADS_PER_WARP) + - tile_id_y * row_length); + const IType *const my_input_tile = + param.input + (tile_id_x * 
nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP; + OType *const my_output_t_tile = + param.output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP; + CType *const my_partial_dbias_tile = + param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length); - OVec * const my_scratch = reinterpret_cast(scratch) + - (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * - (THREADS_PER_WARP + 1); + OVec *const my_scratch = + reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1); - CVec * const my_dbias_scratch = reinterpret_cast(scratch); + CVec *const my_dbias_scratch = reinterpret_cast(scratch); IVec in[2][nvec_out]; const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; @@ -123,9 +119,8 @@ transpose_dbias_kernel(const Param param, const size_t stride = row_length / nvec_in; const size_t output_stride = num_rows / nvec_out; size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; - unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; + unsigned int my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1; partial_dbias.clear(); @@ -147,11 +142,8 @@ transpose_dbias_kernel(const Param param, } OVec out_trans[nvec_in]; // NOLINT(*) transpose_regs_partial_dbias( - in[current_in ^ 1], - out_trans, - partial_dbias, - scale_inv, - (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP); + in[current_in ^ 1], out_trans, partial_dbias, scale_inv, + (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP); #pragma unroll for (unsigned int j = 0; j < nvec_in; ++j) { @@ -164,14 +156,13 @@ transpose_dbias_kernel(const Param param, for (unsigned int i = 0; i < nvec_in; ++i) { #pragma unroll for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space[j][i]; } __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - current_stride = i * output_stride + - warp_id_in_tile * n_iterations * output_stride * nvec_in; + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; for (unsigned int j = 0; j < n_iterations; ++j) { my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, current_stride + my_place); @@ -199,12 +190,9 @@ transpose_dbias_kernel(const Param param, } template -__global__ void -__launch_bounds__(cast_transpose_num_threads) -transpose_dbias_kernel_notaligned(const Param param, - const size_t row_length, - const size_t num_rows, - const size_t num_tiles) { +__global__ void __launch_bounds__(cast_transpose_num_threads) + transpose_dbias_kernel_notaligned(const Param param, const size_t row_length, + const size_t num_rows, const size_t num_tiles) { using IType = typename Param::InputType; using OType = typename Param::OutputType; using CType = typename Param::ComputeType; @@ -216,38 +204,35 @@ transpose_dbias_kernel_notaligned(const Param param, const int warp_id = 
threadIdx.x / THREADS_PER_WARP; const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; - const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) / - (nvec_in * THREADS_PER_WARP); - const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + - warp_id / n_warps_per_tile; + const size_t num_tiles_x = + (row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP); + const size_t tile_id = + blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile; if (tile_id >= num_tiles) return; const size_t tile_id_x = tile_id % num_tiles_x; const size_t tile_id_y = tile_id / num_tiles_x; - const IType * const my_input_tile = param.input + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out + - tile_id_x * num_rows * nvec_in) * - THREADS_PER_WARP; - CType * const my_partial_dbias_tile = param.workspace + - (tile_id_x * (nvec_in * THREADS_PER_WARP) + - tile_id_y * row_length); + const IType *const my_input_tile = + param.input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP; + OType *const my_output_t_tile = + param.output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP; + CType *const my_partial_dbias_tile = + param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length); const size_t stride = row_length / nvec_in; const size_t output_stride = num_rows / nvec_out; const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; - const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP - : row_length_rest; - const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP - : row_height_rest; + const unsigned int tile_length = + row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest; + const unsigned int tile_height = + row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest; - OVec * const my_scratch = reinterpret_cast(scratch) + - (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * - (THREADS_PER_WARP + 1); + OVec *const my_scratch = + reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1); - CVec * const my_dbias_scratch = reinterpret_cast(scratch); + CVec *const my_dbias_scratch = reinterpret_cast(scratch); IVec in[2][nvec_out]; const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; @@ -256,16 +241,14 @@ transpose_dbias_kernel_notaligned(const Param param, CVec partial_dbias; size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; - unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; + unsigned int my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; const CType scale_inv = param.scale_inv != nullptr ? 
*param.scale_inv : 1; partial_dbias.clear(); { - const bool valid_load = my_place < tile_length && - warp_id_in_tile * n_iterations < tile_height; + const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height; #pragma unroll for (unsigned int i = 0; i < nvec_out; ++i) { if (valid_load) { @@ -280,8 +263,8 @@ transpose_dbias_kernel_notaligned(const Param param, const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; const unsigned int current_in = (i + 1) % 2; if (i < n_iterations - 1) { - const bool valid_load = my_place_in < tile_length && - warp_id_in_tile * n_iterations + i + 1 < tile_height; + const bool valid_load = + my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height; #pragma unroll for (unsigned int j = 0; j < nvec_out; ++j) { if (valid_load) { @@ -294,11 +277,8 @@ transpose_dbias_kernel_notaligned(const Param param, } OVec out_trans[nvec_in]; // NOLINT(*) transpose_regs_partial_dbias( - in[current_in ^ 1], - out_trans, - partial_dbias, - scale_inv, - (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP); + in[current_in ^ 1], out_trans, partial_dbias, scale_inv, + (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP); #pragma unroll for (unsigned int j = 0; j < nvec_in; ++j) { @@ -311,14 +291,13 @@ transpose_dbias_kernel_notaligned(const Param param, for (unsigned int i = 0; i < nvec_in; ++i) { #pragma unroll for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space[j][i]; } __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - current_stride = i * output_stride + - warp_id_in_tile * n_iterations * output_stride * nvec_in; + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { const bool valid_store = my_place < tile_height; if (valid_store) { @@ -352,27 +331,25 @@ transpose_dbias_kernel_notaligned(const Param param, constexpr size_t reduce_dbias_num_threads = 256; -template -__global__ void -__launch_bounds__(reduce_dbias_num_threads) -reduce_dbias_kernel(OutputType* const dbias_output, - const ComputeType* const dbias_partial, - const int row_length, - const int num_rows) { +template +__global__ void __launch_bounds__(reduce_dbias_num_threads) + reduce_dbias_kernel(OutputType *const dbias_output, const ComputeType *const dbias_partial, + const int row_length, const int num_rows) { using ComputeVec = Vec; - using OutputVec = Vec; + using OutputVec = Vec; const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; if (thread_id * nvec >= row_length) return; - const ComputeType* const thread_in_base = dbias_partial + thread_id * nvec; - OutputType* const thread_out_base = dbias_output + thread_id * nvec; + const ComputeType *const thread_in_base = dbias_partial + thread_id * nvec; + OutputType *const thread_out_base = dbias_output + thread_id * nvec; const int stride_in_vec = row_length / nvec; ComputeVec ldg_vec; - ComputeVec acc_vec; acc_vec.clear(); + ComputeVec acc_vec; + acc_vec.clear(); for (int i = 0; i < num_rows; ++i) { 
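    // Each iteration loads one row of the partial-dbias workspace (one row
    // per warp-tile row produced by the transpose pass) as a vector of nvec
    // ComputeType values and accumulates it into acc_vec; the accumulated
    // column sums are cast to the bias output type only once, after this loop.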
ldg_vec.load_from(thread_in_base, i * stride_in_vec); #pragma unroll @@ -381,7 +358,7 @@ reduce_dbias_kernel(OutputType* const dbias_output, } } - OutputVec stg_vec; + OutputVec stg_vec; #pragma unroll for (int e = 0; e < nvec; ++e) { stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]); @@ -390,10 +367,9 @@ reduce_dbias_kernel(OutputType* const dbias_output, } void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ - Tensor* workspace, - const int nvec_out) { + Tensor *workspace, const int nvec_out) { const size_t row_length = input.data.shape[1]; - const size_t num_rows = input.data.shape[0]; + const size_t num_rows = input.data.shape[0]; const size_t tile_size_y = (nvec_out * THREADS_PER_WARP); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); @@ -405,37 +381,28 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ } template -void reduce_dbias(const Tensor &workspace, Tensor *dbias, - const size_t row_length, const size_t num_rows, const int nvec_out, - cudaStream_t stream) { - constexpr int reduce_dbias_store_bytes = 8; // stg.64 - constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(BiasType); +void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_length, + const size_t num_rows, const int nvec_out, cudaStream_t stream) { + constexpr int reduce_dbias_store_bytes = 8; // stg.64 + constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(BiasType); NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape."); const size_t reduce_dbias_row_length = row_length; - const size_t reduce_dbias_num_rows = DIVUP(num_rows, - static_cast(nvec_out * - THREADS_PER_WARP)); - const size_t reduce_dbias_num_blocks = DIVUP(row_length, - reduce_dbias_num_threads * reduce_dbias_nvec); + const size_t reduce_dbias_num_rows = + DIVUP(num_rows, static_cast(nvec_out * THREADS_PER_WARP)); + const size_t reduce_dbias_num_blocks = + DIVUP(row_length, reduce_dbias_num_threads * reduce_dbias_nvec); reduce_dbias_kernel - <<>>( - reinterpret_cast(dbias->data.dptr), - reinterpret_cast(workspace.data.dptr), - reduce_dbias_row_length, - reduce_dbias_num_rows); + <<>>( + reinterpret_cast(dbias->data.dptr), + reinterpret_cast(workspace.data.dptr), reduce_dbias_row_length, + reduce_dbias_num_rows); } -void fp8_transpose_dbias(const Tensor &input, - Tensor *transposed_output, - Tensor *dbias, - Tensor *workspace, - cudaStream_t stream) { +void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor *dbias, + Tensor *workspace, cudaStream_t stream) { CheckInputTensor(input, "fp8_transpose_dbias_input"); CheckOutputTensor(*transposed_output, "transposed_output"); CheckOutputTensor(*dbias, "dbias"); @@ -449,82 +416,71 @@ void fp8_transpose_dbias(const Tensor &input, NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); NVTE_CHECK(transposed_output->data.dtype == input.data.dtype, - "T output must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{ row_length }, "Wrong shape of DBias."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dbias->data.dtype, BiasType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.data.dtype, Type, - constexpr int type_size = sizeof(Type); - constexpr int nvec_in = desired_load_size / type_size; - constexpr int nvec_out = desired_store_size / type_size; - - if (workspace->data.dptr == nullptr) { - populate_transpose_dbias_workspace_config(input, workspace, nvec_out); - return; - } - - NVTE_CHECK(row_length % 
nvec_in == 0, "Unsupported shape."); - NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); - const size_t n_tiles = DIVUP(row_length, static_cast(nvec_in * THREADS_PER_WARP)) * - DIVUP(num_rows, static_cast(nvec_out * THREADS_PER_WARP)); - const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP; - const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block); - - const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && - num_rows % (nvec_out * THREADS_PER_WARP) == 0; - - using ComputeType = fp32; - constexpr size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile * - (THREADS_PER_WARP + 1) * - sizeof(Vec); - constexpr size_t shared_size_dbias = cast_transpose_num_threads * - sizeof(Vec); - static_assert(shared_size_transpose >= shared_size_dbias); - using Param = TDBiasParam; - Param param; - param.input = reinterpret_cast(input.data.dptr); - param.output_t = reinterpret_cast(transposed_output->data.dptr); - param.scale_inv = reinterpret_cast(transposed_output->scale_inv.dptr); - param.workspace = reinterpret_cast(workspace->data.dptr); - - if (full_tile) { - cudaFuncSetAttribute(transpose_dbias_kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, - 100); - transpose_dbias_kernel - <<>>(param, row_length, num_rows, n_tiles); - } else { - cudaFuncSetAttribute(transpose_dbias_kernel_notaligned, - cudaFuncAttributePreferredSharedMemoryCarveout, - 100); - transpose_dbias_kernel_notaligned - <<>>(param, row_length, num_rows, n_tiles); - } - - reduce_dbias(*workspace, dbias, row_length, num_rows, nvec_out, stream); - ); // NOLINT(*) - ); // NOLINT(*) + "T output must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{row_length}, "Wrong shape of DBias."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + dbias->data.dtype, BiasType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.data.dtype, Type, constexpr int type_size = sizeof(Type); + constexpr int nvec_in = desired_load_size / type_size; + constexpr int nvec_out = desired_store_size / type_size; + + if (workspace->data.dptr == nullptr) { + populate_transpose_dbias_workspace_config(input, workspace, nvec_out); + return; + } + + NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); + NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); + const size_t n_tiles = + DIVUP(row_length, static_cast(nvec_in * THREADS_PER_WARP)) * + DIVUP(num_rows, static_cast(nvec_out * THREADS_PER_WARP)); + const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP; + const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block); + + const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && + num_rows % (nvec_out * THREADS_PER_WARP) == 0; + + using ComputeType = fp32; constexpr size_t shared_size_transpose = + cast_transpose_num_threads / n_warps_per_tile * + (THREADS_PER_WARP + 1) * sizeof(Vec); + constexpr size_t shared_size_dbias = + cast_transpose_num_threads * sizeof(Vec); + static_assert(shared_size_transpose >= shared_size_dbias); + using Param = TDBiasParam; Param param; + param.input = reinterpret_cast(input.data.dptr); + param.output_t = reinterpret_cast(transposed_output->data.dptr); + param.scale_inv = + reinterpret_cast(transposed_output->scale_inv.dptr); + param.workspace = reinterpret_cast(workspace->data.dptr); + + if (full_tile) { + cudaFuncSetAttribute(transpose_dbias_kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + transpose_dbias_kernel + <<>>( + param, 
row_length, num_rows, n_tiles); + } else { + cudaFuncSetAttribute(transpose_dbias_kernel_notaligned, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + transpose_dbias_kernel_notaligned + <<>>( + param, row_length, num_rows, n_tiles); + } + + reduce_dbias(*workspace, dbias, row_length, num_rows, nvec_out, + stream);); // NOLINT(*) + ); // NOLINT(*) } - } // namespace transformer_engine -void nvte_fp8_transpose_dbias(const NVTETensor input, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream) { +void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_fp8_transpose_dbias); using namespace transformer_engine; - fp8_transpose_dbias(*reinterpret_cast(input), - reinterpret_cast(transposed_output), - reinterpret_cast(dbias), - reinterpret_cast(workspace), - stream); + fp8_transpose_dbias( + *reinterpret_cast(input), reinterpret_cast(transposed_output), + reinterpret_cast(dbias), reinterpret_cast(workspace), stream); } diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index 4744bc3aafa353fab7fafe46f1f19ff40038ef0a..8dd2b98ebfdd2e458c816fb6e5e82674b392ec98 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -5,9 +5,10 @@ ************************************************************************/ #include + #include "../common.h" -#include "../utils.cuh" #include "../util/vectorized_pointwise.h" +#include "../utils.cuh" namespace transformer_engine { @@ -15,9 +16,7 @@ namespace detail { struct Empty {}; -__device__ inline fp32 identity(fp32 value, const Empty&) { - return value; -} +__device__ inline fp32 identity(fp32 value, const Empty &) { return value; } struct DequantizeParam { const fp32 *scale_inv; @@ -29,83 +28,63 @@ __device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam ¶m) } // namespace detail -void fp8_quantize(const Tensor &input, - Tensor *output, - cudaStream_t stream) { +void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) { CheckInputTensor(input, "cast_input"); CheckOutputTensor(*output, "cast_output"); - NVTE_CHECK(!is_fp8_dtype(input.data.dtype), - "Input must be in higher precision."); + NVTE_CHECK(!is_fp8_dtype(input.data.dtype), "Input must be in higher precision."); - NVTE_CHECK(is_fp8_dtype(output->data.dtype), - "Output must have FP8 type."); + NVTE_CHECK(is_fp8_dtype(output->data.dtype), "Output must have FP8 type."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - N, - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), N, {}, + stream);); // NOLINT(*) + ); // NOLINT(*) } -void 
fp8_dequantize(const Tensor &input, - Tensor *output, - cudaStream_t stream) { +void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { CheckInputTensor(input, "cast_input"); CheckOutputTensor(*output, "cast_output"); - NVTE_CHECK(is_fp8_dtype(input.data.dtype), - "Input must have FP8 type."); + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); - NVTE_CHECK(!is_fp8_dtype(output->data.dtype), - "Output must be in higher precision."); + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(OType); - detail::DequantizeParam p; - p.scale_inv = reinterpret_cast(input.scale_inv.dptr); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - nullptr, - nullptr, - N, - p, - stream); - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + output->data.dtype, OType, constexpr int nvec = 32 / sizeof(OType); + detail::DequantizeParam p; + p.scale_inv = reinterpret_cast(input.scale_inv.dptr); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), nullptr, nullptr, N, p, + stream);); // NOLINT(*) + ); // NOLINT(*) } } // namespace transformer_engine -void nvte_fp8_quantize(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_fp8_quantize); using namespace transformer_engine; - fp8_quantize(*reinterpret_cast(input), - reinterpret_cast(output), + fp8_quantize(*reinterpret_cast(input), reinterpret_cast(output), stream); } -void nvte_fp8_dequantize(const NVTETensor input, - NVTETensor output, - cudaStream_t stream) { +void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_fp8_dequantize); using namespace transformer_engine; - fp8_dequantize(*reinterpret_cast(input), - reinterpret_cast(output), + fp8_dequantize(*reinterpret_cast(input), reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/util/cuda_driver.cpp b/transformer_engine/common/util/cuda_driver.cpp index f04d3861b84afd13081d57f0729f45fa13c87e4f..3dff6434c1e05e3e8f46ba7ce0bd3eb40e355232 100644 --- a/transformer_engine/common/util/cuda_driver.cpp +++ b/transformer_engine/common/util/cuda_driver.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include + #include #include "../common.h" @@ -40,27 +41,21 @@ class Library { #endif // _WIN32 or _WIN64 or __WINDOW__ } - Library(const Library&) = delete; // move-only + Library(const Library &) = delete; // move-only - Library(Library&& other) noexcept { - swap(*this, other); - } + Library(Library &&other) noexcept { swap(*this, other); } - Library& operator=(Library other) noexcept { + Library &operator=(Library other) noexcept { // Copy-and-swap idiom swap(*this, other); return *this; } - friend void swap(Library& first, Library& second) noexcept; + friend void swap(Library &first, Library &second) noexcept; - void *get() noexcept { - return handle_; - } + void *get() noexcept { return 
handle_; } - const void *get() const noexcept { - return handle_; - } + const void *get() const noexcept { return handle_; } /*! \brief Get pointer corresponding to symbol in shared library */ void *get_symbol(const char *symbol) { @@ -78,13 +73,13 @@ class Library { void *handle_ = nullptr; }; -void swap(Library& first, Library& second) noexcept { +void swap(Library &first, Library &second) noexcept { using std::swap; swap(first.handle_, second.handle_); } /*! \brief Lazily-initialized shared library for CUDA driver */ -Library& cuda_driver_lib() { +Library &cuda_driver_lib() { #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) constexpr char lib_name[] = "nvcuda.dll"; #else @@ -98,9 +93,7 @@ Library& cuda_driver_lib() { namespace cuda_driver { -void *get_symbol(const char *symbol) { - return cuda_driver_lib().get_symbol(symbol); -} +void *get_symbol(const char *symbol) { return cuda_driver_lib().get_symbol(symbol); } } // namespace cuda_driver diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index 65a4c9e882c51653d08613b22c3a3a681850c33e..9dc1114580742e5f684988978c0e8f1ee8fcc53b 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -7,10 +7,10 @@ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ -#include - #include +#include + #include "../common.h" #include "../util/string.h" @@ -35,7 +35,7 @@ void *get_symbol(const char *symbol); template inline CUresult call(const char *symbol, ArgTs... args) { using FuncT = CUresult(ArgTs...); - FuncT *func = reinterpret_cast(get_symbol(symbol)); + FuncT *func = reinterpret_cast(get_symbol(symbol)); return (*func)(args...); } @@ -43,23 +43,20 @@ inline CUresult call(const char *symbol, ArgTs... args) { } // namespace transformer_engine -#define NVTE_CHECK_CUDA_DRIVER(expr) \ - do { \ - const CUresult status_NVTE_CHECK_CUDA_DRIVER = (expr); \ - if (status_NVTE_CHECK_CUDA_DRIVER != CUDA_SUCCESS) { \ - const char *desc_NVTE_CHECK_CUDA_DRIVER; \ - ::transformer_engine::cuda_driver::call( \ - "cuGetErrorString", \ - status_NVTE_CHECK_CUDA_DRIVER, \ - &desc_NVTE_CHECK_CUDA_DRIVER); \ - NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER); \ - } \ +#define NVTE_CHECK_CUDA_DRIVER(expr) \ + do { \ + const CUresult status_NVTE_CHECK_CUDA_DRIVER = (expr); \ + if (status_NVTE_CHECK_CUDA_DRIVER != CUDA_SUCCESS) { \ + const char *desc_NVTE_CHECK_CUDA_DRIVER; \ + ::transformer_engine::cuda_driver::call("cuGetErrorString", status_NVTE_CHECK_CUDA_DRIVER, \ + &desc_NVTE_CHECK_CUDA_DRIVER); \ + NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER); \ + } \ } while (false) -#define NVTE_CALL_CHECK_CUDA_DRIVER(symbol, ...) \ - do { \ - NVTE_CHECK_CUDA_DRIVER( \ - ::transformer_engine::cuda_driver::call(#symbol, __VA_ARGS__)); \ +#define NVTE_CALL_CHECK_CUDA_DRIVER(symbol, ...) \ + do { \ + NVTE_CHECK_CUDA_DRIVER(::transformer_engine::cuda_driver::call(#symbol, __VA_ARGS__)); \ } while (false) #endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 54d3d2720b829287a251594eb59b404a1a858eba..5728ef557aa950b6ed919e67f56f17a20f339828 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -4,12 +4,13 @@ * See LICENSE for license information. 
************************************************************************/ +#include "../util/cuda_runtime.h" + #include #include #include "../common.h" #include "../util/cuda_driver.h" -#include "../util/cuda_runtime.h" #include "../util/system.h" namespace transformer_engine { @@ -24,7 +25,7 @@ namespace { } // namespace int num_devices() { - auto query_num_devices = [] () -> int { + auto query_num_devices = []() -> int { int count; NVTE_CHECK_CUDA(cudaGetDeviceCount(&count)); return count; @@ -54,10 +55,10 @@ int sm_arch(int device_id) { device_id = current_device(); } NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); - auto init = [&] () { + auto init = [&]() { cudaDeviceProp prop; NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id)); - cache[device_id] = 10*prop.major + prop.minor; + cache[device_id] = 10 * prop.major + prop.minor; }; std::call_once(flags[device_id], init); return cache[device_id]; @@ -70,7 +71,7 @@ int sm_count(int device_id) { device_id = current_device(); } NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); - auto init = [&] () { + auto init = [&]() { cudaDeviceProp prop; NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id)); cache[device_id] = prop.multiProcessorCount; @@ -90,12 +91,11 @@ const std::string &include_directory(bool required) { if (need_to_check_env) { // Search for CUDA headers in common paths using Path = std::filesystem::path; - std::vector> search_paths = { - {"NVTE_CUDA_INCLUDE_DIR", ""}, - {"CUDA_HOME", ""}, - {"CUDA_DIR", ""}, - {"", string_path_cuda_include}, - {"", "/usr/local/cuda"}}; + std::vector> search_paths = {{"NVTE_CUDA_INCLUDE_DIR", ""}, + {"CUDA_HOME", ""}, + {"CUDA_DIR", ""}, + {"", string_path_cuda_include}, + {"", "/usr/local/cuda"}}; for (auto &[env, p] : search_paths) { if (p.empty()) { p = getenv(env.c_str()); @@ -131,10 +131,11 @@ const std::string &include_directory(bool required) { message += p; } } - message += (". " - "Specify path to CUDA Toolkit headers " - "with NVTE_CUDA_INCLUDE_DIR " - "or disable NVRTC support with NVTE_DISABLE_NVRTC=1."); + message += + (". " + "Specify path to CUDA Toolkit headers " + "with NVTE_CUDA_INCLUDE_DIR " + "or disable NVRTC support with NVTE_DISABLE_NVRTC=1."); NVTE_ERROR(message); } need_to_check_env = false; diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index f4188336dd01f3af47fc514a4b5df04f09557679..b6b4c4161044af7c6b2c2eae040cc17d50b299bb 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -8,6 +8,7 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_ #include + #include namespace transformer_engine { diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 227b6e96b5ca3c753280e74fa30beb9a645f32a6..7972db31626d061f713167d863ada4807b4256c5 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -7,83 +7,76 @@ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ -#include - #include #include #include #include +#include + #include "../util/string.h" -#define NVTE_ERROR(...) \ - do { \ - throw ::std::runtime_error( \ - ::transformer_engine::concat_strings( \ - __FILE__ ":", __LINE__, \ - " in function ", __func__, ": ", \ - ::transformer_engine::concat_strings(__VA_ARGS__))); \ +#define NVTE_ERROR(...) 
\ + do { \ + throw ::std::runtime_error(::transformer_engine::concat_strings( \ + __FILE__ ":", __LINE__, " in function ", __func__, ": ", \ + ::transformer_engine::concat_strings(__VA_ARGS__))); \ } while (false) -#define NVTE_CHECK(expr, ...) \ - do { \ - if (!(expr)) { \ - NVTE_ERROR("Assertion failed: " #expr ". ", \ - ::transformer_engine::concat_strings(__VA_ARGS__)); \ - } \ +#define NVTE_CHECK(expr, ...) \ + do { \ + if (!(expr)) { \ + NVTE_ERROR("Assertion failed: " #expr ". ", \ + ::transformer_engine::concat_strings(__VA_ARGS__)); \ + } \ } while (false) -#define NVTE_CHECK_CUDA(expr) \ - do { \ - const cudaError_t status_NVTE_CHECK_CUDA = (expr); \ - if (status_NVTE_CHECK_CUDA != cudaSuccess) { \ - NVTE_ERROR("CUDA Error: ", \ - cudaGetErrorString(status_NVTE_CHECK_CUDA)); \ - } \ +#define NVTE_CHECK_CUDA(expr) \ + do { \ + const cudaError_t status_NVTE_CHECK_CUDA = (expr); \ + if (status_NVTE_CHECK_CUDA != cudaSuccess) { \ + NVTE_ERROR("CUDA Error: ", cudaGetErrorString(status_NVTE_CHECK_CUDA)); \ + } \ } while (false) -#define NVTE_CHECK_CUBLAS(expr) \ - do { \ - const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \ - if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \ - NVTE_ERROR("cuBLAS Error: ", \ - cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \ - } \ +#define NVTE_CHECK_CUBLAS(expr) \ + do { \ + const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \ + if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \ + NVTE_ERROR("cuBLAS Error: ", cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \ + } \ } while (false) -#define NVTE_CHECK_CUDNN(expr) \ - do { \ - const cudnnStatus_t status_NVTE_CHECK_CUDNN = (expr); \ - if (status_NVTE_CHECK_CUDNN != CUDNN_STATUS_SUCCESS) { \ - NVTE_ERROR("cuDNN Error: ", \ - cudnnGetErrorString(status_NVTE_CHECK_CUDNN), \ - ". " \ - "For more information, enable cuDNN error logging " \ - "by setting CUDNN_LOGERR_DBG=1 and " \ - "CUDNN_LOGDEST_DBG=stderr in the environment."); \ - } \ +#define NVTE_CHECK_CUDNN(expr) \ + do { \ + const cudnnStatus_t status_NVTE_CHECK_CUDNN = (expr); \ + if (status_NVTE_CHECK_CUDNN != CUDNN_STATUS_SUCCESS) { \ + NVTE_ERROR("cuDNN Error: ", cudnnGetErrorString(status_NVTE_CHECK_CUDNN), \ + ". " \ + "For more information, enable cuDNN error logging " \ + "by setting CUDNN_LOGERR_DBG=1 and " \ + "CUDNN_LOGDEST_DBG=stderr in the environment."); \ + } \ } while (false) -#define NVTE_CHECK_CUDNN_FE(expr) \ - do { \ - const auto error = (expr); \ - if (error.is_bad()) { \ - NVTE_ERROR("cuDNN Error: ", \ - error.err_msg, \ - ". " \ - "For more information, enable cuDNN error logging " \ - "by setting CUDNN_LOGERR_DBG=1 and " \ - "CUDNN_LOGDEST_DBG=stderr in the environment."); \ - } \ +#define NVTE_CHECK_CUDNN_FE(expr) \ + do { \ + const auto error = (expr); \ + if (error.is_bad()) { \ + NVTE_ERROR("cuDNN Error: ", error.err_msg, \ + ". 
" \ + "For more information, enable cuDNN error logging " \ + "by setting CUDNN_LOGERR_DBG=1 and " \ + "CUDNN_LOGDEST_DBG=stderr in the environment."); \ + } \ } while (false) -#define NVTE_CHECK_NVRTC(expr) \ - do { \ - const nvrtcResult status_NVTE_CHECK_NVRTC = (expr); \ - if (status_NVTE_CHECK_NVRTC != NVRTC_SUCCESS) { \ - NVTE_ERROR("NVRTC Error: ", \ - nvrtcGetErrorString(status_NVTE_CHECK_NVRTC)); \ - } \ +#define NVTE_CHECK_NVRTC(expr) \ + do { \ + const nvrtcResult status_NVTE_CHECK_NVRTC = (expr); \ + if (status_NVTE_CHECK_NVRTC != NVRTC_SUCCESS) { \ + NVTE_ERROR("NVRTC Error: ", nvrtcGetErrorString(status_NVTE_CHECK_NVRTC)); \ + } \ } while (false) #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 1736fe17b6c16e7f6994107df72794f33cd2fcf7..2625c97e7928279399764e0e4120f1bc62cf01ad 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -13,75 +13,73 @@ struct Empty {}; template __device__ inline OType gelu(const IType val, const Empty&) { - const float cval = val; - return cval * (0.5F + 0.5F * tanhf(cval * (0.79788456F + 0.03567741F * cval * cval))); + const float cval = val; + return cval * (0.5F + 0.5F * tanhf(cval * (0.79788456F + 0.03567741F * cval * cval))); } template __device__ inline OType dgelu(const IType val, const Empty&) { - const float cval = val; - const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval)); - return 0.5f * cval * ((1.f - tanh_out * tanh_out) * - (0.79788456f + 0.1070322243f * cval * cval)) + - 0.5f * (1.f + tanh_out); + const float cval = val; + const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval)); + return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) + + 0.5f * (1.f + tanh_out); } template __device__ inline OType sigmoid(const IType val, const Empty&) { - const float cval = val; - return 1.f / (1.f + expf(-cval)); + const float cval = val; + return 1.f / (1.f + expf(-cval)); } template __device__ inline OType dsigmoid(const IType val, const Empty& e) { - const float cval = val; - const float s = sigmoid(cval, e); - return s * (1.f - s); + const float cval = val; + const float s = sigmoid(cval, e); + return s * (1.f - s); } template __device__ inline OType qgelu(const IType val, const Empty& e) { - const float cval = val; - return cval * sigmoid(1.702f * cval, e); + const float cval = val; + return cval * sigmoid(1.702f * cval, e); } template __device__ inline OType dqgelu(const IType val, const Empty& e) { - const float cval = val; - return cval * dsigmoid(1.702f * cval, e) + - sigmoid(1.702f * cval, e); + const float cval = val; + return cval * dsigmoid(1.702f * cval, e) + sigmoid(1.702f * cval, e); } template __device__ inline OType silu(const IType val, const Empty& e) { - const float cval = val; - return cval * sigmoid(cval, e); + const float cval = val; + return cval * sigmoid(cval, e); } template __device__ inline OType dsilu(const IType val, const Empty& e) { - const float cval = val; - return cval * dsigmoid(cval, e) + sigmoid(cval, e); + const float cval = val; + return cval * dsigmoid(cval, e) + sigmoid(cval, e); } template -__device__ inline OType relu(IType value, const Empty &) { - return fmaxf(value, 0.f); +__device__ inline OType relu(IType value, const Empty&) { + return fmaxf(value, 0.f); } template -__device__ inline OType drelu(IType value, const Empty &) { - return value > 0.f ? 
1.f : 0.f; +__device__ inline OType drelu(IType value, const Empty&) { + return value > 0.f ? 1.f : 0.f; } template -__device__ inline OType srelu(IType value, const Empty &) { - return value > 0 ? value * value : 0.f; +__device__ inline OType srelu(IType value, const Empty&) { + return value > 0 ? value * value : 0.f; } template -__device__ inline OType dsrelu(IType value, const Empty &) { - return fmaxf(2.f * value, 0.f); +__device__ inline OType dsrelu(IType value, const Empty&) { + return fmaxf(2.f * value, 0.f); } } // namespace transformer_engine diff --git a/transformer_engine/common/util/rtc.cpp b/transformer_engine/common/util/rtc.cpp index 16ac507f8a62a29f67128e2a04dc31d4c39fb867..c03654bfc53b103b3d87b5c5fc67a62f788cd877 100644 --- a/transformer_engine/common/util/rtc.cpp +++ b/transformer_engine/common/util/rtc.cpp @@ -4,6 +4,8 @@ * See LICENSE for license information. ************************************************************************/ +#include "../util/rtc.h" + #include #include #include @@ -13,8 +15,6 @@ #include "../util/string.h" #include "../util/system.h" -#include "../util/rtc.h" - namespace transformer_engine { namespace rtc { @@ -22,8 +22,8 @@ namespace rtc { namespace { // Strings with headers for RTC kernels -#include "string_code_utils_cuh.h" #include "string_code_util_math_h.h" +#include "string_code_utils_cuh.h" /*! \brief Latest compute capability that NVRTC supports * @@ -56,29 +56,25 @@ bool is_enabled() { } Kernel::Kernel(std::string mangled_name, std::string compiled_code) - : mangled_name_{std::move(mangled_name)} - , compiled_code_{std::move(compiled_code)} - , modules_(cuda::num_devices(), null_module) - , functions_(cuda::num_devices(), null_function) - , init_flags_{std::make_unique>(cuda::num_devices())} { -} + : mangled_name_{std::move(mangled_name)}, + compiled_code_{std::move(compiled_code)}, + modules_(cuda::num_devices(), null_module), + functions_(cuda::num_devices(), null_function), + init_flags_{std::make_unique>(cuda::num_devices())} {} Kernel::~Kernel() { - for (int device_id=0; device_id(modules_.size()); ++device_id) { + for (int device_id = 0; device_id < static_cast(modules_.size()); ++device_id) { // Unload CUDA modules if needed if (modules_[device_id] != null_module) { CUdevice device; CUcontext context; - if (cuda_driver::call("cuDeviceGet", &device, device_id) - != CUDA_SUCCESS) { + if (cuda_driver::call("cuDeviceGet", &device, device_id) != CUDA_SUCCESS) { continue; } - if (cuda_driver::call("cuDevicePrimaryCtxRetain", &context, device) - != CUDA_SUCCESS) { + if (cuda_driver::call("cuDevicePrimaryCtxRetain", &context, device) != CUDA_SUCCESS) { continue; } - if (cuda_driver::call("cuCtxSetCurrent", context) - != CUDA_SUCCESS) { + if (cuda_driver::call("cuCtxSetCurrent", context) != CUDA_SUCCESS) { continue; } cuda_driver::call("cuModuleUnload", modules_[device_id]); @@ -87,9 +83,7 @@ Kernel::~Kernel() { } } -Kernel::Kernel(Kernel&& other) noexcept { - swap(*this, other); -} +Kernel::Kernel(Kernel&& other) noexcept { swap(*this, other); } Kernel& Kernel::operator=(Kernel other) noexcept { // Copy-and-swap idiom @@ -108,7 +102,7 @@ void swap(Kernel& first, Kernel& second) noexcept { CUfunction Kernel::get_function(int device_id) { // Load kernel on device if needed - auto load_on_device = [&] () { + auto load_on_device = [&]() { // Set driver context to proper device CUdevice device; CUcontext context; @@ -117,15 +111,11 @@ CUfunction Kernel::get_function(int device_id) { NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context); // 
Load function into driver context - NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx, - &modules_[device_id], - compiled_code_.c_str(), - 0, // numOptions - nullptr, // options - nullptr); // optionValues - NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleGetFunction, - &functions_[device_id], - modules_[device_id], + NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx, &modules_[device_id], compiled_code_.c_str(), + 0, // numOptions + nullptr, // options + nullptr); // optionValues + NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleGetFunction, &functions_[device_id], modules_[device_id], mangled_name_.c_str()); // Reset driver context @@ -147,10 +137,8 @@ KernelManager& KernelManager::instance() { return instance_; } -void KernelManager::compile(const std::string &kernel_label, - const std::string &kernel_name, - const std::string &code, - const std::string &filename) { +void KernelManager::compile(const std::string& kernel_label, const std::string& kernel_name, + const std::string& code, const std::string& filename) { std::lock_guard lock_guard_(lock_); // Choose whether to compile to PTX or cubin @@ -162,9 +150,9 @@ void KernelManager::compile(const std::string &kernel_label, // Compilation flags std::vector opts = { #if NDEBUG == 0 - "-G", + "-G", #endif - "--std=c++17"}; + "--std=c++17"}; if (compile_ptx) { opts.push_back(concat_strings("--gpu-architecture=compute_", compile_sm_arch)); } else { @@ -181,20 +169,14 @@ void KernelManager::compile(const std::string &kernel_label, constexpr int num_headers = 2; constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h}; constexpr const char* include_names[num_headers] = {"utils.cuh", "util/math.h"}; - NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program, - code.c_str(), - filename.c_str(), - num_headers, - headers, - include_names)); + NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program, code.c_str(), filename.c_str(), num_headers, + headers, include_names)); NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str())); - const nvrtcResult compile_result = nvrtcCompileProgram(program, - opts_ptrs.size(), - opts_ptrs.data()); + const nvrtcResult compile_result = + nvrtcCompileProgram(program, opts_ptrs.size(), opts_ptrs.data()); if (compile_result != NVRTC_SUCCESS) { // Display log if compilation failed - std::string log = concat_strings("NVRTC compilation log for ", - filename, ":\n"); + std::string log = concat_strings("NVRTC compilation log for ", filename, ":\n"); const size_t log_offset = log.size(); size_t log_size; NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size)); @@ -206,10 +188,8 @@ void KernelManager::compile(const std::string &kernel_label, } // Get mangled function name - const char *mangled_name; - NVTE_CHECK_NVRTC(nvrtcGetLoweredName(program, - kernel_name.c_str(), - &mangled_name)); + const char* mangled_name; + NVTE_CHECK_NVRTC(nvrtcGetLoweredName(program, kernel_name.c_str(), &mangled_name)); // Get compiled code std::string compiled_code; @@ -234,20 +214,19 @@ void KernelManager::compile(const std::string &kernel_label, NVTE_CHECK_NVRTC(nvrtcDestroyProgram(&program)); } -void KernelManager::set_cache_config(const std::string &kernel_label, CUfunc_cache cache_config) { +void KernelManager::set_cache_config(const std::string& kernel_label, CUfunc_cache cache_config) { const int device_id = cuda::current_device(); const auto key = get_kernel_cache_key(kernel_label, device_id); - NVTE_CHECK(kernel_cache_.count(key) > 0, - "Attempted to configure RTC kernel before compilation"); + NVTE_CHECK(kernel_cache_.count(key) > 
0, "Attempted to configure RTC kernel before compilation"); kernel_cache_.at(key).set_function_cache_config(device_id, cache_config); } -bool KernelManager::is_compiled(const std::string &kernel_label, int device_id) const { +bool KernelManager::is_compiled(const std::string& kernel_label, int device_id) const { const auto key = get_kernel_cache_key(kernel_label, device_id); return kernel_cache_.count(key) > 0; } -std::string KernelManager::get_kernel_cache_key(const std::string &kernel_label, +std::string KernelManager::get_kernel_cache_key(const std::string& kernel_label, int device_id) const { return concat_strings("sm=", cuda::sm_arch(device_id), ",", kernel_label); } diff --git a/transformer_engine/common/util/rtc.h b/transformer_engine/common/util/rtc.h index 52dd1a7a61133acab6e24871346aa5ce410bd1c2..2c79d038b2529943ee9ecb0d0a581951335f5fa4 100644 --- a/transformer_engine/common/util/rtc.h +++ b/transformer_engine/common/util/rtc.h @@ -7,6 +7,10 @@ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_ +#include +#include +#include + #include #include #include @@ -14,10 +18,6 @@ #include #include -#include -#include -#include - #include "../common.h" #include "../util/cuda_driver.h" #include "../util/cuda_runtime.h" @@ -38,10 +38,10 @@ class Kernel { public: Kernel(std::string mangled_name, std::string compiled_code); ~Kernel(); - Kernel(const Kernel&) = delete; // move-only - Kernel(Kernel&&) noexcept; - Kernel& operator=(Kernel) noexcept; - friend void swap(Kernel& first, Kernel& second) noexcept; + Kernel(const Kernel &) = delete; // move-only + Kernel(Kernel &&) noexcept; + Kernel &operator=(Kernel) noexcept; + friend void swap(Kernel &first, Kernel &second) noexcept; /*! \brief Launch CUDA kernel * @@ -57,25 +57,12 @@ class Kernel { * \param[in] args Kernel arguments */ template - void launch(int device_id, - const dim3 grid_dim, - const dim3 block_dim, - unsigned int shared_mem_bytes, - cudaStream_t stream, - ArgTs &&... args) { - void* arg_ptrs[] = { const_cast(static_cast(&args))... }; - NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, - get_function(device_id), - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z, - shared_mem_bytes, - static_cast(stream), - arg_ptrs, - nullptr); + void launch(int device_id, const dim3 grid_dim, const dim3 block_dim, + unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) { + void *arg_ptrs[] = {const_cast(static_cast(&args))...}; + NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y, + grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes, + static_cast(stream), arg_ptrs, nullptr); } /*! \brief CUDA function for given CUDA device @@ -114,7 +101,7 @@ class Kernel { class KernelManager { public: /*! \brief Get singleton instance */ - static KernelManager& instance(); + static KernelManager &instance(); /*! \brief Compile CUDA kernel for current CUDA device * @@ -126,10 +113,8 @@ class KernelManager { * \param[in] filename Path to associate with source code, * primarily for debugging */ - void compile(const std::string &kernel_label, - const std::string &kernel_name, - const std::string &code, - const std::string &filename); + void compile(const std::string &kernel_label, const std::string &kernel_name, + const std::string &code, const std::string &filename); /*! 
\brief Whether CUDA kernel has been compiled for CUDA device * @@ -138,8 +123,7 @@ class KernelManager { * \return Whether kernel has been compiled */ - bool is_compiled(const std::string &kernel_label, - int device_id = -1) const; + bool is_compiled(const std::string &kernel_label, int device_id = -1) const; /*! \brief Launch CUDA kernel on current CUDA device * @@ -154,21 +138,12 @@ class KernelManager { * \param[in] args Kernel arguments */ template - void launch(const std::string &kernel_label, - const dim3 grid_dim, - const dim3 block_dim, - unsigned int shared_mem_bytes, - cudaStream_t stream, - ArgTs &&... args) { + void launch(const std::string &kernel_label, const dim3 grid_dim, const dim3 block_dim, + unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) { const int device_id = cuda::current_device(); const auto key = get_kernel_cache_key(kernel_label, device_id); - NVTE_CHECK(kernel_cache_.count(key) > 0, - "Attempted to launch RTC kernel before compilation"); - kernel_cache_.at(key).launch(device_id, - grid_dim, - block_dim, - shared_mem_bytes, - stream, + NVTE_CHECK(kernel_cache_.count(key) > 0, "Attempted to launch RTC kernel before compilation"); + kernel_cache_.at(key).launch(device_id, grid_dim, block_dim, shared_mem_bytes, stream, std::forward(args)...); } @@ -189,8 +164,8 @@ class KernelManager { KernelManager() = default; ~KernelManager() = default; - KernelManager(const KernelManager&) = delete; - KernelManager& operator=(const KernelManager&) = delete; + KernelManager(const KernelManager &) = delete; + KernelManager &operator=(const KernelManager &) = delete; /*! \brief Construct key for kernel cache * @@ -199,8 +174,7 @@ class KernelManager { * * \return Key for kernel cache */ - std::string get_kernel_cache_key(const std::string &kernel_label, - int device_id) const; + std::string get_kernel_cache_key(const std::string &kernel_label, int device_id) const; }; } // namespace rtc diff --git a/transformer_engine/common/util/string.h b/transformer_engine/common/util/string.h index 5c117cdb4c3986e5ee76f659f73164b106805fe0..c0a2aa1077d6391a30c160b3ff10a9cfe8cee1a6 100644 --- a/transformer_engine/common/util/string.h +++ b/transformer_engine/common/util/string.h @@ -14,23 +14,18 @@ namespace transformer_engine { /*! \brief Convert to C-style or C++-style string */ -template ::value>::type> +template ::value>::type> inline std::string to_string_like(const T &val) { return std::to_string(val); } -inline const std::string& to_string_like(const std::string& val) noexcept { - return val; -} +inline const std::string &to_string_like(const std::string &val) noexcept { return val; } -constexpr const char *to_string_like(const char *val) noexcept { - return val; -} +constexpr const char *to_string_like(const char *val) noexcept { return val; } /*! \brief Convert arguments to strings and concatenate */ template -inline std::string concat_strings(const Ts &... args) { +inline std::string concat_strings(const Ts &...args) { std::string str; str.reserve(1024); // Assume strings are <1 KB (..., (str += to_string_like(args))); @@ -42,12 +37,9 @@ inline std::string concat_strings(const Ts &... args) { * This is a convenience wrapper around std::regex_replace. 
*/ template -inline std::string regex_replace(const std::string &str, - const std::string &pattern, +inline std::string regex_replace(const std::string &str, const std::string &pattern, const T &replacement) { - return std::regex_replace(str, - std::regex(pattern), - to_string_like(replacement)); + return std::regex_replace(str, std::regex(pattern), to_string_like(replacement)); } } // namespace transformer_engine diff --git a/transformer_engine/common/util/system.cpp b/transformer_engine/common/util/system.cpp index 2508f8e1c017293c70030148b3682db3c19a81c3..0659061b475c2b73862639f5cbffc3fe89bbcd8a 100644 --- a/transformer_engine/common/util/system.cpp +++ b/transformer_engine/common/util/system.cpp @@ -4,6 +4,8 @@ * See LICENSE for license information. ************************************************************************/ +#include "../util/system.h" + #include #include #include @@ -12,15 +14,14 @@ #include #include "../common.h" -#include "../util/system.h" namespace transformer_engine { namespace { template -inline typename std::enable_if::value, T>::type -getenv_helper(const char *variable, const T &default_value) { +inline typename std::enable_if::value, T>::type getenv_helper( + const char *variable, const T &default_value) { // Implementation for numeric types const char *env = std::getenv(variable); if (env == nullptr || env[0] == '\0') { @@ -34,8 +35,8 @@ getenv_helper(const char *variable, const T &default_value) { } template -inline typename std::enable_if::value, T>::type -getenv_helper(const char *variable, const T &default_value) { +inline typename std::enable_if::value, T>::type getenv_helper( + const char *variable, const T &default_value) { // Implementation for string-like types const char *env = std::getenv(variable); if (env == nullptr || env[0] == '\0') { @@ -47,13 +48,14 @@ getenv_helper(const char *variable, const T &default_value) { } // namespace -#define NVTE_INSTANTIATE_GETENV(T, default_value) \ - template <> T getenv(const char *variable, \ - const T &default_value_) { \ - return getenv_helper(variable, default_value_); \ - } \ - template <> T getenv(const char *variable) { \ - return getenv_helper(variable, default_value); \ +#define NVTE_INSTANTIATE_GETENV(T, default_value) \ + template <> \ + T getenv(const char *variable, const T &default_value_) { \ + return getenv_helper(variable, default_value_); \ + } \ + template <> \ + T getenv(const char *variable) { \ + return getenv_helper(variable, default_value); \ } NVTE_INSTANTIATE_GETENV(bool, false); NVTE_INSTANTIATE_GETENV(float, 0.f); @@ -69,8 +71,6 @@ NVTE_INSTANTIATE_GETENV(uint64_t, 0); NVTE_INSTANTIATE_GETENV(std::string, std::string()); NVTE_INSTANTIATE_GETENV(std::filesystem::path, std::filesystem::path()); -bool file_exists(const std::string &path) { - return static_cast(std::ifstream(path.c_str())); -} +bool file_exists(const std::string &path) { return static_cast(std::ifstream(path.c_str())); } } // namespace transformer_engine diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index ac7d796935f7441c2f713ed1152f1d7778d765e4..63ad1857cf958125e80eeb1a74f334224c98b4e8 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -8,6 +8,7 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_ #include + #include "../common.h" #include "../utils.cuh" @@ -30,15 +31,13 @@ class VectorizedStorage { } scratch_; inline __device__ VectorizedStorage() {} - 
inline __device__ VectorizedStorage(const VectorizedStorage& y2) { - scratch_.aligned = y2.scratch_.aligned; - } - inline __device__ VectorizedStorage(const LType &y2) { - scratch_.aligned = y2; + inline __device__ VectorizedStorage(const VectorizedStorage &y2) { + scratch_.aligned = y2.scratch_.aligned; } - inline __device__ VectorizedStorage& operator+=( - const VectorizedStorage& rhs) { - #pragma unroll + inline __device__ VectorizedStorage(const LType &y2) { scratch_.aligned = y2; } + inline __device__ VectorizedStorage &operator+=( + const VectorizedStorage &rhs) { +#pragma unroll for (int i = 0; i < nvec; ++i) { scratch_.separate[i] = add_elem(scratch_.separate[i], rhs.scratch_.separate[i]); } @@ -58,7 +57,6 @@ struct select_const { using type = const LType; }; - /* \brief Helper class that enables accessing multiple values of type DType as 1 value of type LType. Additional aligned template argument allows performance optimizations if the pointer and the size of @@ -67,44 +65,37 @@ struct select_const { template class VectorizedAccessor { public: - using StorageType = VectorizedStorage::type, - nvec>; + using StorageType = VectorizedStorage::type, nvec>; using LType = typename select_const::type; StorageType storage_; - LType* aligned_ptr_; - DType* unaligned_ptr_; + LType *aligned_ptr_; + DType *unaligned_ptr_; int alignment_; size_t n_elems_; - inline __device__ VectorizedAccessor(DType* const ptr, const size_t size) { + inline __device__ VectorizedAccessor(DType *const ptr, const size_t size) { unaligned_ptr_ = ptr; if (aligned) { alignment_ = 0; - aligned_ptr_ = reinterpret_cast(ptr); + aligned_ptr_ = reinterpret_cast(ptr); n_elems_ = (size + nvec - 1) / nvec; } else { size_t ptr_as_number = reinterpret_cast(ptr); alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType); - aligned_ptr_ = reinterpret_cast(ptr - alignment_); + aligned_ptr_ = reinterpret_cast(ptr - alignment_); n_elems_ = (size + alignment_ + nvec - 1) / nvec; } } /* \brief Alignment of the input pointer in elements. */ - inline __device__ int alignment() const { - return alignment_; - } + inline __device__ int alignment() const { return alignment_; } /* \brief Access to separate elements. */ - inline __device__ DType* separate() { - return storage_.scratch_.separate; - } + inline __device__ DType *separate() { return storage_.scratch_.separate; } /* \brief Number of aligned elements that span the entire input tensor. */ - inline __device__ size_t num_aligned_elements() const { - return n_elems_; - } + inline __device__ size_t num_aligned_elements() const { return n_elems_; } /* \brief Load values from the input. \param id Aligned index of the element. @@ -119,7 +110,7 @@ class VectorizedAccessor { } else { #pragma unroll for (int j = 0; j < nvec; ++j) { - DType* ptr = reinterpret_cast(&(aligned_ptr_[id])) + j; + DType *ptr = reinterpret_cast(&(aligned_ptr_[id])) + j; if (reinterpret_cast(ptr) >= reinterpret_cast(unaligned_ptr_) && reinterpret_cast(ptr) < reinterpret_cast(unaligned_ptr_ + N)) { storage_.scratch_.separate[j] = *ptr; @@ -136,18 +127,16 @@ class VectorizedAccessor { template class VectorizedLoader : public VectorizedAccessor { public: - inline __device__ VectorizedLoader(const DType* ptr, const size_t N) : - VectorizedAccessor(ptr, N) { - } + inline __device__ VectorizedLoader(const DType *ptr, const size_t N) + : VectorizedAccessor(ptr, N) {} }; /* \brief Class used for vectorized writable access. 
*/ template class VectorizedStorer : public VectorizedAccessor { public: - inline __device__ VectorizedStorer(DType* ptr, const size_t N) : - VectorizedAccessor(ptr, N) { - } + inline __device__ VectorizedStorer(DType *ptr, const size_t N) + : VectorizedAccessor(ptr, N) {} /* \brief Store values to the output. \param id Aligned index of the element. @@ -162,7 +151,7 @@ class VectorizedStorer : public VectorizedAccessor { } else { #pragma unroll for (int j = 0; j < nvec; ++j) { - DType* ptr = reinterpret_cast(&(this->aligned_ptr_[id])) + j; + DType *ptr = reinterpret_cast(&(this->aligned_ptr_[id])) + j; if (reinterpret_cast(ptr) >= reinterpret_cast(this->unaligned_ptr_) && reinterpret_cast(ptr) < reinterpret_cast(this->unaligned_ptr_ + N)) { *ptr = this->storage_.scratch_.separate[j]; @@ -175,34 +164,24 @@ class VectorizedStorer : public VectorizedAccessor { constexpr int unary_kernel_threads = 512; -template -__launch_bounds__(unary_kernel_threads) -__global__ void unary_kernel(const InputType *input, - OutputType *output, - const ComputeType *scale, - ComputeType *amax, - Param p, - const size_t N, - const size_t num_aligned_elements) { +template +__launch_bounds__(unary_kernel_threads) __global__ + void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale, + ComputeType *amax, Param p, const size_t N, + const size_t num_aligned_elements) { VectorizedLoader loader(input, N); VectorizedStorer storer(output, N); ComputeType max = 0; ComputeType s = 0; if constexpr (is_fp8::value) { - if (scale != nullptr) s = *scale; + if (scale != nullptr) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; const size_t M = num_aligned_elements; - for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - tid < M; - tid += gridDim.x * blockDim.x) { + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { loader.load(tid, N); #pragma unroll for (int i = 0; i < nvec; ++i) { @@ -224,43 +203,32 @@ __global__ void unary_kernel(const InputType *input, max = reduce_max(max, warp_id); if (threadIdx.x == 0 && amax != nullptr) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } } } -template -__launch_bounds__(unary_kernel_threads) -__global__ void unary_grad_kernel(const InputTypeGrad *grad, - const InputType *input, - OutputType *output, - const ComputeType *scale, - ComputeType *amax, - Param p, - const size_t N, - const size_t num_aligned_elements) { +__launch_bounds__(unary_kernel_threads) __global__ + void unary_grad_kernel(const InputTypeGrad *grad, const InputType *input, OutputType *output, + const ComputeType *scale, ComputeType *amax, Param p, const size_t N, + const size_t num_aligned_elements) { VectorizedLoader loader(input, N); VectorizedLoader grad_loader(grad, N); VectorizedStorer storer(output, N); ComputeType max = 0; ComputeType s = 0; if constexpr (is_fp8::value) { - if (scale != nullptr) s = *scale; + if (scale != nullptr) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; const size_t M = num_aligned_elements; - for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - tid < M; - tid += gridDim.x * blockDim.x) { + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { loader.load(tid, N); grad_loader.load(tid, N); #pragma unroll @@ -284,25 +252,25 @@ __global__ void unary_grad_kernel(const InputTypeGrad *grad, max = reduce_max(max, warp_id); if (threadIdx.x == 0 && 
amax != nullptr) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } } } namespace { -inline size_t get_num_aligned_elements(const void *ptr, const size_t lead_dim, - const int nvec, const int size) { +inline size_t get_num_aligned_elements(const void *ptr, const size_t lead_dim, const int nvec, + const int size) { size_t ptr_as_number = reinterpret_cast(ptr); int alignment = (ptr_as_number % (nvec * size)) / size; return DIVUP(lead_dim + alignment, static_cast(nvec)); } enum class Alignment { - SAME_ALIGNED, // All tensors aligned + SAME_ALIGNED, // All tensors aligned SAME_UNALIGNED, // All tensors have the same misalignment - DIFFERENT // Tensors have different alignment + DIFFERENT // Tensors have different alignment }; inline int CalcAlignment(const void *ptr, const int size) { @@ -317,10 +285,7 @@ inline int CalcAlignment(const void *ptr, const int size) { \param ptrs Inputs and Outputs to the operator. */ template -Alignment CheckAlignment(const size_t lead_dim, - const int nvec, - const T... ptrs - ) { +Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... ptrs) { std::vector alignments; alignments.reserve(sizeof...(T)); @@ -328,13 +293,12 @@ Alignment CheckAlignment(const size_t lead_dim, (..., alignments.push_back(CalcAlignment(ptrs, sizeof(*ptrs) * nvec))); bool all_same = std::all_of(alignments.cbegin(), alignments.cend(), - [alignments](int val) {return val == alignments.front();}); + [alignments](int val) { return val == alignments.front(); }); if (!all_same) { return Alignment::DIFFERENT; } - if (alignments.front() == 0 && - lead_dim % nvec == 0) { + if (alignments.front() == 0 && lead_dim % nvec == 0) { // all alignment are 0 return Alignment::SAME_ALIGNED; } else { @@ -344,22 +308,15 @@ Alignment CheckAlignment(const size_t lead_dim, } // namespace -template -void VectorizedUnaryKernelLauncher(const InputType *input, - OutputType *output, - const fp32 *scale, - fp32 *amax, - const size_t N, - const Param params, +void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale, + fp32 *amax, const size_t N, const Param params, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); - size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, - sizeof(InputType)); + size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; size_t num_blocks = DIVUP(num_aligned_elements, threads); constexpr size_t max_blocks = 65535; @@ -376,32 +333,23 @@ void VectorizedUnaryKernelLauncher(const InputType *input, break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_kernel<1, true, fp32, Param, OP><<>>( - input, output, scale, amax, params, N, N); + unary_kernel<1, true, fp32, Param, OP> + <<>>(input, output, scale, amax, params, N, N); break; } } } } -template -void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, - const InputType *input, - OutputType *output, - const fp32 *scale, - fp32 *amax, - const size_t N, - const Param params, - cudaStream_t stream) { +template +void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input, + OutputType *output, const fp32 *scale, fp32 *amax, + const size_t N, const Param params, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, grad, output); - size_t 
num_aligned_elements = get_num_aligned_elements(input, N, nvec, - sizeof(InputType)); + size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; size_t num_blocks = DIVUP(num_aligned_elements, threads); constexpr size_t max_blocks = 65535; @@ -418,33 +366,23 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_grad_kernel<1, true, fp32, Param, OP><<>>( - grad, input, output, scale, amax, params, N, N); + unary_grad_kernel<1, true, fp32, Param, OP> + <<>>(grad, input, output, scale, amax, params, N, N); break; } } } } -template -__launch_bounds__(unary_kernel_threads) -__global__ void gated_act_kernel(const InputType *input, - OutputType *output, - const ComputeType *scale, - ComputeType *amax, - const size_t m, - const size_t n, - const Param p, - const size_t num_aligned_elements) { +__launch_bounds__(unary_kernel_threads) __global__ + void gated_act_kernel(const InputType *input, OutputType *output, const ComputeType *scale, + ComputeType *amax, const size_t m, const size_t n, const Param p, + const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; - for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - tid < M; - tid += gridDim.x * blockDim.x) { + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; const size_t id_y = tid / num_aligned_elements; VectorizedLoader loader0(input + id_y * n * 2, n); @@ -453,7 +391,7 @@ __global__ void gated_act_kernel(const InputType *input, ComputeType max = 0; ComputeType s = 0; if constexpr (is_fp8::value) { - if (scale != nullptr) s = *scale; + if (scale != nullptr) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; @@ -478,26 +416,18 @@ __global__ void gated_act_kernel(const InputType *input, max = reduce_max(max, warp_id); if (threadIdx.x == 0 && amax != nullptr) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } } } } -template -void GatedActivationKernelLauncher(const InputType *input, - OutputType *output, - const fp32 *scale, - fp32 *amax, - const size_t m, - const size_t n, - const Param &p, +void GatedActivationKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale, + fp32 *amax, const size_t m, const size_t n, const Param &p, cudaStream_t stream) { if (m != 0 && n != 0) { size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType)); @@ -509,44 +439,34 @@ void GatedActivationKernelLauncher(const InputType *input, switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) { case Alignment::SAME_ALIGNED: gated_act_kernel - <<>>( - input, output, scale, amax, m, n, p, num_aligned_elements); + <<>>(input, output, scale, amax, m, n, p, + num_aligned_elements); break; case Alignment::SAME_UNALIGNED: gated_act_kernel - <<>>( - input, output, scale, amax, m, n, p, num_aligned_elements); + <<>>(input, output, scale, amax, m, n, p, + num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize gated_act_kernel<1, true, ComputeType, Param, Activation> - <<>>( - input, output, scale, amax, m, n, p, n); + <<>>(input, output, scale, amax, m, n, p, n); break; } } } } -template 
-__launch_bounds__(unary_kernel_threads) -__global__ void dgated_act_kernel(const InputType *grad, - const InputType *input, - OutputType *output, - const size_t m, - const size_t n, - const Param p, - const size_t num_aligned_elements) { +__launch_bounds__(unary_kernel_threads) __global__ + void dgated_act_kernel(const InputType *grad, const InputType *input, OutputType *output, + const size_t m, const size_t n, const Param p, + const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; - for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - tid < M; - tid += gridDim.x * blockDim.x) { + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; const size_t id_y = tid / num_aligned_elements; VectorizedLoader grad_loader(grad + id_y * n, n); @@ -576,23 +496,15 @@ __global__ void dgated_act_kernel(const InputType *grad, } } -template -void DGatedActivationKernelLauncher(const InputType *grad, - const InputType *input, - OutputType *output, - const size_t m, - const size_t n, - const Param &p, - cudaStream_t stream) { +void DGatedActivationKernelLauncher(const InputType *grad, const InputType *input, + OutputType *output, const size_t m, const size_t n, + const Param &p, cudaStream_t stream) { if (m != 0 && n != 0) { - size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec, - sizeof(InputType)); + size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; size_t num_blocks = DIVUP(num_aligned_elements * m, threads); constexpr size_t max_blocks = 65535; @@ -601,16 +513,18 @@ void DGatedActivationKernelLauncher(const InputType *grad, switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) { case Alignment::SAME_ALIGNED: dgated_act_kernel - <<>>(grad, input, output, m, n, p, num_aligned_elements); + <<>>(grad, input, output, m, n, p, + num_aligned_elements); break; case Alignment::SAME_UNALIGNED: dgated_act_kernel - <<>>(grad, input, output, m, n, p, num_aligned_elements); + <<>>(grad, input, output, m, n, p, + num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize dgated_act_kernel<1, true, ComputeType, Param, Activation, Dactivation> - <<>>(grad, input, output, m, n, p, n); + <<>>(grad, input, output, m, n, p, n); break; } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index ddd7fe0673997f96124a6e835cdd46797f17cc84..bcfc0c608dd7c626642370203441f3c662dd773c 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -31,47 +31,45 @@ constexpr uint32_t THREADS_PER_WARP = 32; //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ float2 operator+(const float2 & a, const float2 & b) { // NOLINT(*) - return {a.x + b.x, a.y + b.y}; +inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*) + return {a.x + b.x, a.y + b.y}; } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void operator+=(float2 & a, const float2 & b) { // NOLINT(*) - a.x += b.x; - a.y += b.y; +inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT(*) + a.x += b.x; + a.y += b.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// 
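// Illustrative sketch, not part of this patch: the float2 operator+= just above and the
// warp_shuffle_xor overload defined a few hunks below are the pieces a butterfly warp
// reduction is built from. The name "warp_reduce_sum2" is hypothetical and used only here.
__device__ inline float2 warp_reduce_sum2(float2 v) {
  for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2) {
    v += warp_shuffle_xor(v, static_cast<uint32_t>(offset));  // pairwise exchange + packed add
  }
  return v;  // after log2(32) steps every lane holds the warp-wide sum of both components
}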
-template +template struct Sum { - inline __device__ Sum() {} - inline __device__ T operator()(const T &a, const T &b) const { - return a + b; - } + inline __device__ Sum() {} + inline __device__ T operator()(const T &a, const T &b) const { return a + b; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx) { - return __shfl_xor_sync(static_cast(-1), x, idx); +template +inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) { + return __shfl_xor_sync(static_cast(-1), x, idx); } -template<> -inline __device__ float2 warp_shuffle_xor(const float2 & x, uint32_t idx) { - return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) }; +template <> +inline __device__ float2 warp_shuffle_xor(const float2 &x, uint32_t idx) { + return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)}; } -template -inline __device__ T warp_shuffle_down(const T & x, uint32_t idx) { - return __shfl_down_sync(static_cast(-1), x, idx); +template +inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) { + return __shfl_down_sync(static_cast(-1), x, idx); } -template<> -inline __device__ float2 warp_shuffle_down(const float2 & x, uint32_t idx) { - return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) }; +template <> +inline __device__ float2 warp_shuffle_down(const float2 &x, uint32_t idx) { + return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)}; } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,533 +79,517 @@ namespace transformer_engine { //////////////////////////////////////////////////////////////////////////////////////////////////// struct uint16 { - uint4 u; - uint4 v; - uint4 s; - uint4 t; + uint4 u; + uint4 v; + uint4 s; + uint4 t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct uint8 { - uint4 u; - uint4 v; + uint4 u; + uint4 v; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct BytesToType {}; -template<> +template <> struct BytesToType<64> { - using Type = uint16; - static_assert(sizeof(Type) == 64); + using Type = uint16; + static_assert(sizeof(Type) == 64); }; -template<> +template <> struct BytesToType<32> { - using Type = uint8; - static_assert(sizeof(Type) == 32); + using Type = uint8; + static_assert(sizeof(Type) == 32); }; -template<> +template <> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); + using Type = uint4; + static_assert(sizeof(Type) == 16); }; -template<> +template <> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); + using Type = uint64_t; + static_assert(sizeof(Type) == 8); }; -template<> +template <> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); + using Type = uint32_t; + static_assert(sizeof(Type) == 4); }; -template<> +template <> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); + using Type = uint16_t; + static_assert(sizeof(Type) == 2); }; -template<> +template <> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); + using Type = uint8_t; + static_assert(sizeof(Type) == 1); }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct TypeToVec2 {}; -template<> 
+template <> struct TypeToVec2 { - using Type = float2; + using Type = float2; }; -template<> +template <> struct TypeToVec2 { - using Type = half2; + using Type = half2; }; -template<> +template <> struct TypeToVec2 { - using Type = nv_bfloat162; + using Type = nv_bfloat162; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct CTDBiasDActParam { - using InputType = IType; - using InputType2 = IType2; - using OutputType = OType; - using ComputeType = CType; - const IType *input; - const IType2 *act_input; - OType *output_c; - OType *output_t; - const CType *scale_ptr; - CType *amax; - CType *scale_inv; - CType *workspace; - CType *warp_scales_inv; + using InputType = IType; + using InputType2 = IType2; + using OutputType = OType; + using ComputeType = CType; + const IType *input; + const IType2 *act_input; + OType *output_c; + OType *output_t; + const CType *scale_ptr; + CType *amax; + CType *scale_inv; + CType *workspace; + CType *warp_scales_inv; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Get { - template - static inline __device__ R of(const T &vec); + template + static inline __device__ R of(const T &vec); }; -template<> -template +template <> +template inline __device__ R Get<0>::of(const T &vec) { - return vec.x; + return vec.x; } -template<> -template +template <> +template inline __device__ R Get<1>::of(const T &vec) { - return vec.y; + return vec.y; } -template<> -template +template <> +template inline __device__ R Get<2>::of(const T &vec) { - return vec.z; + return vec.z; } -template<> -template +template <> +template inline __device__ R Get<3>::of(const T &vec) { - return vec.w; + return vec.w; } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Converter{ - static inline __device__ Dst convert(const Src &from) { - return Dst(from); - } +template +struct Converter { + static inline __device__ Dst convert(const Src &from) { return Dst(from); } }; -template<> -struct Converter{ - static inline __device__ half2 convert(const float2 &x) { - return __float22half2_rn(x); - } +template <> +struct Converter { + static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); } }; -template<> -struct Converter{ - static inline __device__ nv_bfloat162 convert(const float2 &x) { +template <> +struct Converter { + static inline __device__ nv_bfloat162 convert(const float2 &x) { #if __CUDA_ARCH__ >= 800 - return __float22bfloat162_rn(x); + return __float22bfloat162_rn(x); #else - union { - nv_bfloat162 raw; - nv_bfloat16 elt[2]; - } tmp; - tmp.elt[0] = __float2bfloat16_rn(x.x); - tmp.elt[1] = __float2bfloat16_rn(x.y); - return tmp.raw; + union { + nv_bfloat162 raw; + nv_bfloat16 elt[2]; + } tmp; + tmp.elt[0] = __float2bfloat16_rn(x.x); + tmp.elt[1] = __float2bfloat16_rn(x.y); + return tmp.raw; #endif - } + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Zeros{ - static inline __device__ T get() { - return T(0.f); - } +template +struct Zeros { + static inline __device__ T get() { return T(0.f); } }; -template<> -struct Zeros{ - static inline __device__ float2 get() { - return make_float2(0.f, 0.f); - } +template <> +struct Zeros { + static inline __device__ float2 get() { return make_float2(0.f, 0.f); } }; 
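BytesToType and TypeToVec2 above let a pack of elements travel as a single builtin word and let per-row statistics be carried as a two-component vector. A compile-time sketch of the size-to-type mapping (simplified, illustrative names):

```cuda
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

template <int BYTES> struct BytesToTypeDemo;
template <> struct BytesToTypeDemo<16> { using Type = uint4; };
template <> struct BytesToTypeDemo<8> { using Type = uint64_t; };
template <> struct BytesToTypeDemo<4> { using Type = uint32_t; };

// A pack of NUM_ELT elements is stored in the builtin type of matching size,
// so its loads and stores can compile to a single vectorized instruction.
template <typename Elt, int NUM_ELT>
using PackStorage = typename BytesToTypeDemo<NUM_ELT * sizeof(Elt)>::Type;

static_assert(sizeof(PackStorage<float, 4>) == 16, "4 floats travel as one uint4");
static_assert(sizeof(PackStorage<__half, 4>) == 8, "4 halves travel as one uint64_t");

int main() { return 0; }
```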
//////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Vec { - enum { BYTES = NUM_ELT * sizeof(Elt_type) }; - - using Vec_type = typename BytesToType::Type; - using type = Elt_type; - - using Alias_type = union { - Vec_type vec; - Elt_type elt[NUM_ELT]; - }; + enum { BYTES = NUM_ELT * sizeof(Elt_type) }; - Alias_type data; + using Vec_type = typename BytesToType::Type; + using type = Elt_type; - template - inline __device__ void to(Vec &other) { // NOLINT(*) - #pragma unroll - for ( int it = 0; it < NUM_ELT; it++ ) { - other.data.elt[it] = S(this->data.elt[it]); - } - } + using Alias_type = union { + Vec_type vec; + Elt_type elt[NUM_ELT]; + }; - template - inline __device__ void assign(const Op &op) { - #pragma unroll - for ( int it = 0; it < NUM_ELT; it++ ) { - this->data.elt[it] = op(it); - } - } + Alias_type data; - // Pointer is cast to vector type - inline __device__ void load_from(const void *base_ptr, size_t idx = 0) { - this->data.vec = static_cast(base_ptr)[idx]; - } - - // Pointer is cast to vector type - inline __device__ void store_to(void *base_ptr, size_t idx = 0) const { - static_cast(base_ptr)[idx] = this->data.vec; - } - - // Pointer is cast to element type. Loads min(count, NUM_ELT) - // elements and any remaining elements are set to zero. - inline __device__ void load_from_elts(const void *base_ptr, - size_t idx = 0, - size_t count = NUM_ELT) { - const Elt_type *elt_ptr = static_cast(base_ptr) + idx; - if ( count < NUM_ELT - || reinterpret_cast(elt_ptr) % BYTES != 0 ) { - #pragma unroll - for ( int it = 0; it < NUM_ELT; it++ ) { - this->data.elt[it] = (it < count - ? elt_ptr[it] - : Elt_type(0.f)); - } - } else { - this->load_from(elt_ptr); - } + template + inline __device__ void to(Vec &other) { // NOLINT(*) +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + other.data.elt[it] = S(this->data.elt[it]); } + } - // Pointer is cast to element type. Stores min(count, NUM_ELT) - // elements. - inline __device__ void store_to_elts(void *base_ptr, - size_t idx = 0, - size_t count = NUM_ELT) const { - Elt_type *elt_ptr = static_cast(base_ptr) + idx; - if ( count < NUM_ELT - || reinterpret_cast(elt_ptr) % BYTES != 0 ) { - #pragma unroll - for ( int it = 0; it < NUM_ELT; it++ ) { - if ( it < count ) { - elt_ptr[it] = this->data.elt[it]; - } - } - } else { - this->store_to(elt_ptr); + template + inline __device__ void assign(const Op &op) { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + this->data.elt[it] = op(it); + } + } + + // Pointer is cast to vector type + inline __device__ void load_from(const void *base_ptr, size_t idx = 0) { + this->data.vec = static_cast(base_ptr)[idx]; + } + + // Pointer is cast to vector type + inline __device__ void store_to(void *base_ptr, size_t idx = 0) const { + static_cast(base_ptr)[idx] = this->data.vec; + } + + // Pointer is cast to element type. Loads min(count, NUM_ELT) + // elements and any remaining elements are set to zero. + inline __device__ void load_from_elts(const void *base_ptr, size_t idx = 0, + size_t count = NUM_ELT) { + const Elt_type *elt_ptr = static_cast(base_ptr) + idx; + if (count < NUM_ELT || reinterpret_cast(elt_ptr) % BYTES != 0) { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + this->data.elt[it] = (it < count ? elt_ptr[it] : Elt_type(0.f)); + } + } else { + this->load_from(elt_ptr); + } + } + + // Pointer is cast to element type. Stores min(count, NUM_ELT) + // elements. 
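load_from_elts above (and the matching store_to_elts that continues below) takes the single-instruction vector path only when the pointer is suitably aligned and a full pack remains in the row; otherwise it falls back to a scalar loop that zero-fills the missing lanes. A reduced, runnable sketch of that logic for a fixed float x4 pack:

```cuda
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

struct Vec4 {
  union { float4 vec; float elt[4]; } data;

  // Vector load when aligned and at least 4 elements remain, scalar tail otherwise.
  __host__ __device__ void load_from_elts(const float *ptr, size_t count) {
    if (count < 4 || reinterpret_cast<uintptr_t>(ptr) % 16 != 0) {
      for (int i = 0; i < 4; ++i) {
        data.elt[i] = (static_cast<size_t>(i) < count) ? ptr[i] : 0.f;
      }
    } else {
      data.vec = *reinterpret_cast<const float4 *>(ptr);
    }
  }
};

int main() {
  alignas(16) float row[6] = {1, 2, 3, 4, 5, 6};
  Vec4 v;
  v.load_from_elts(row, 6);      // full pack -> single float4 load
  v.load_from_elts(row + 4, 2);  // tail of length 2 -> {5, 6, 0, 0}
  printf("%g %g %g %g\n", v.data.elt[0], v.data.elt[1], v.data.elt[2], v.data.elt[3]);
  return 0;
}
```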
+ inline __device__ void store_to_elts(void *base_ptr, size_t idx = 0, + size_t count = NUM_ELT) const { + Elt_type *elt_ptr = static_cast(base_ptr) + idx; + if (count < NUM_ELT || reinterpret_cast(elt_ptr) % BYTES != 0) { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + if (it < count) { + elt_ptr[it] = this->data.elt[it]; } + } + } else { + this->store_to(elt_ptr); } + } - inline __device__ void clear() { - #pragma unroll - for ( int it = 0; it < NUM_ELT; it++ ) { - this->data.elt[it] = Elt_type(0.f); - } + inline __device__ void clear() { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + this->data.elt[it] = Elt_type(0.f); } + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct InterCTASync { - inline __device__ InterCTASync(int *barrier, - int group, - int num_groups, - int group_size) - : phase_counter_(0) - , b0_(barrier + group) // The barrier for this group of CTAs. - , b1_(barrier + group + num_groups) // The barrier for this group of CTAs. - , group_size_(group_size) { - // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! - } - - inline __device__ void spin_wait_(int *barrier, int step, int expected) { - asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); - for ( int found = -1; found != expected; ) { - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); - } - } - - inline __device__ void sync() { - // ALL THREADS MUST ENTER! - - // We switch barrier every iteration. - int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; - // We decrement every other iteration. - bool dec = phase_counter_ & 0x2; - int step = dec ? -1 : 1; - int expected = dec ? 0 : group_size_; - // There are only 4 phases: up/down for b0/b1. - phase_counter_ = (phase_counter_ + 1) & 0x3; - - if ( threadIdx.x == 0 ) { - spin_wait_(barrier, step, expected); - } - // CTA waits for thread 0 - __syncthreads(); - } + inline __device__ InterCTASync(int *barrier, int group, int num_groups, int group_size) + : phase_counter_(0), + b0_(barrier + group) // The barrier for this group of CTAs. + , + b1_(barrier + group + num_groups) // The barrier for this group of CTAs. + , + group_size_(group_size) { + // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! + } + + inline __device__ void spin_wait_(int *barrier, int step, int expected) { + asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); + for (int found = -1; found != expected;) { + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); + } + } + + inline __device__ void sync() { + // ALL THREADS MUST ENTER! + + // We switch barrier every iteration. + int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; + // We decrement every other iteration. + bool dec = phase_counter_ & 0x2; + int step = dec ? -1 : 1; + int expected = dec ? 0 : group_size_; + // There are only 4 phases: up/down for b0/b1. 
+ phase_counter_ = (phase_counter_ + 1) & 0x3; + + if (threadIdx.x == 0) { + spin_wait_(barrier, step, expected); + } + // CTA waits for thread 0 + __syncthreads(); + } - int phase_counter_; - int * b0_; - int * b1_; - int group_size_; + int phase_counter_; + int *b0_; + int *b1_; + int group_size_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Reducer : public Reducer { - using Base = Reducer; - using Type = typename Base::Type; - - enum { SMEM_BYTES = Base::SMEM_BYTES }; - - enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; - enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; - - // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) - enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + - WS_DATA_BYTES }; - - template - inline __device__ Reducer(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, - uint32_t warp_n, uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) - , inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW) - , bidn_(bidn) // CTA id within the group. - , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) - , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {} - - template - inline __device__ T allreduce(T data, const Op &op) { - data = Base::reduce(data, op); - // We switch workspace every iteration. - T * const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - // Warp leaders 0 hold the CTA-local results. - if ( this->warp_n_ == 0 && this->lane_ == 0 ) { - workspace[bidn_] = data; - } - inter_cta_.sync(); - static_assert(CTAS_PER_ROW <= 32); - T total = Zeros::get(); - if (this->lane_ < CTAS_PER_ROW) { - total = workspace[this->lane_]; - } - total = Reducer::allreduce_(total, op); - - return total; - } - - InterCTASync inter_cta_; - - T * const w0_; - T * const w1_; - int bidn_; + using Base = Reducer; + using Type = typename Base::Type; + + enum { SMEM_BYTES = Base::SMEM_BYTES }; + + enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; + enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; + + // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) + enum { + WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES + }; + + template + inline __device__ Reducer(const Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t lane, void *smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), + inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW), + bidn_(bidn) // CTA id within the group. + , + w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), + w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {} + + template + inline __device__ T allreduce(T data, const Op &op) { + data = Base::reduce(data, op); + // We switch workspace every iteration. + T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + // Warp leaders 0 hold the CTA-local results. 
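InterCTASync::sync() above cycles through four phases so that neither barrier counter is reused before every CTA in the group has left it: barrier 0 counted up, barrier 1 counted up, barrier 0 counted back down, barrier 1 counted back down. A host-side illustration of that schedule (plain C++, no device code):

```cuda
#include <cstdio>

int main() {
  const int group_size = 4;  // CTAs per group, assumed for the illustration
  int phase_counter = 0;
  for (int it = 0; it < 6; ++it) {
    const char *barrier = (phase_counter & 0x1) ? "b1" : "b0";
    const bool dec = phase_counter & 0x2;
    const int step = dec ? -1 : 1;
    const int expected = dec ? 0 : group_size;
    printf("iter %d: barrier=%s step=%+d wait_until=%d\n", it, barrier, step, expected);
    phase_counter = (phase_counter + 1) & 0x3;
  }
  return 0;
}
```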
+ if (this->warp_n_ == 0 && this->lane_ == 0) { + workspace[bidn_] = data; + } + inter_cta_.sync(); + static_assert(CTAS_PER_ROW <= 32); + T total = Zeros::get(); + if (this->lane_ < CTAS_PER_ROW) { + total = workspace[this->lane_]; + } + total = Reducer::allreduce_(total, op); + + return total; + } + + InterCTASync inter_cta_; + + T *const w0_; + T *const w1_; + int bidn_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Reducer { - using Type = T; - enum { SMEM_BYTES = 0 }; - enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - - enum { THREADS_PER_WARP = 32 }; - - template - inline __device__ Reducer(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, - uint32_t warp_n, uint32_t lane, void * smem) - : warp_n_(warp_n) - , lane_(lane) {} - - template - static inline __device__ T allreduce_(T data, const Op &op) { - #pragma unroll - for ( int it = 1; it < THREADS_PER_WARP; it *= 2 ) { - data = op(data, warp_shuffle_xor(data, it)); - } - return data; - } + using Type = T; + enum { SMEM_BYTES = 0 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - template - inline __device__ T allreduce(T data, const Op &op) { - return allreduce_(data, op); + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(const Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t lane, void *smem) + : warp_n_(warp_n), lane_(lane) {} + + template + static inline __device__ T allreduce_(T data, const Op &op) { +#pragma unroll + for (int it = 1; it < THREADS_PER_WARP; it *= 2) { + data = op(data, warp_shuffle_xor(data, it)); } + return data; + } - template - inline __device__ T reduce(T data, const Op &op) { - // only lane 0 holds the result! - #pragma unroll - for ( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) { - data = op(data, warp_shuffle_down(data, it)); - } - return data; + template + inline __device__ T allreduce(T data, const Op &op) { + return allreduce_(data, op); + } + + template + inline __device__ T reduce(T data, const Op &op) { +// only lane 0 holds the result! 
+#pragma unroll + for (int it = THREADS_PER_WARP / 2; it > 0; it /= 2) { + data = op(data, warp_shuffle_down(data, it)); } - int warp_n_; - int lane_; + return data; + } + int warp_n_; + int lane_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Reducer : public Reducer { - using Base = Reducer; + using Base = Reducer; - using Type = T; + using Type = T; - enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; - enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - enum { THREADS_PER_WARP = 32 }; + enum { THREADS_PER_WARP = 32 }; - template - inline __device__ Reducer(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, - uint32_t warp_n, uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) - , use0_(true), smem0_(&(static_cast(smem)[warp_m * WARPS_N])) - , smem1_(smem0_ + WARPS_M * WARPS_N) {} + template + inline __device__ Reducer(const Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t lane, void *smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), + use0_(true), + smem0_(&(static_cast(smem)[warp_m * WARPS_N])), + smem1_(smem0_ + WARPS_M * WARPS_N) {} - template - inline __device__ T allreduce(T data, const Op & op) { - T * const smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - data = Base::reduce(data, op); - if ( this->lane_ == 0 ) { - smem[this->warp_n_] = data; - } - __syncthreads(); - T out = Zeros::get(); - #pragma unroll - for ( int it = 0; it < WARPS_N; it++ ) { - out = op(out, smem[it]); - } - return out; + template + inline __device__ T allreduce(T data, const Op &op) { + T *const smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + data = Base::reduce(data, op); + if (this->lane_ == 0) { + smem[this->warp_n_] = data; } + __syncthreads(); + T out = Zeros::get(); +#pragma unroll + for (int it = 0; it < WARPS_N; it++) { + out = op(out, smem[it]); + } + return out; + } - template - inline __device__ T reduce(T data, const Op &op) { - T * const smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - // only intra-CTA group leader holds the result! - data = Base::reduce(data, op); - if ( this->lane_ == 0 ) { - smem[this->warp_n_] = data; - } - __syncthreads(); - T out = Zeros::get(); - if ( this->warp_n_ == 0 && this->lane_ == 0 ) { - #pragma unroll - for ( int it = 0; it < WARPS_N; it++ ) { - out = op(out, smem[it]); - } - } - return out; + template + inline __device__ T reduce(T data, const Op &op) { + T *const smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // only intra-CTA group leader holds the result! 
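The shared-memory reducer above follows a two-level shape that also reappears in reduce_max further down: every warp reduces with shuffles, its lane 0 parks the partial in shared memory, and after __syncthreads() the partials are folded. A self-contained block-sum demo of that pattern (illustrative, not the TE reducer):

```cuda
#include <cstdio>

__global__ void block_sum_demo(const float *x, int n, float *out) {
  __shared__ float partial[32];  // one slot per warp
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int lane = threadIdx.x % 32;
  const int warp = threadIdx.x / 32;

  float v = (tid < n) ? x[tid] : 0.f;
  for (int offset = 16; offset > 0; offset /= 2) {  // intra-warp reduce
    v += __shfl_down_sync(0xffffffffu, v, offset);
  }
  if (lane == 0) partial[warp] = v;
  __syncthreads();

  if (warp == 0) {  // warp 0 folds the per-warp partials
    const int num_warps = blockDim.x / 32;
    float s = (lane < num_warps) ? partial[lane] : 0.f;
    for (int offset = 16; offset > 0; offset /= 2) {
      s += __shfl_down_sync(0xffffffffu, s, offset);
    }
    if (lane == 0) atomicAdd(out, s);
  }
}

int main() {
  const int n = 256;
  float h[n], *d_x, *d_out, zero = 0.f, result;
  for (int i = 0; i < n; ++i) h[i] = 1.f;
  cudaMalloc(&d_x, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_x, h, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, &zero, sizeof(float), cudaMemcpyHostToDevice);
  block_sum_demo<<<1, 256>>>(d_x, n, d_out);
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %g (expected %d)\n", result, n);
  cudaFree(d_x);
  cudaFree(d_out);
  return 0;
}
```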
+ data = Base::reduce(data, op); + if (this->lane_ == 0) { + smem[this->warp_n_] = data; } + __syncthreads(); + T out = Zeros::get(); + if (this->warp_n_ == 0 && this->lane_ == 0) { +#pragma unroll + for (int it = 0; it < WARPS_N; it++) { + out = op(out, smem[it]); + } + } + return out; + } - T * const smem0_; - T * const smem1_; - bool use0_; + T *const smem0_; + T *const smem1_; + bool use0_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct DynamicReducer : public Reducer { - using Base = Reducer; - using Type = typename Base::Type; - - template - inline __device__ DynamicReducer(const Params & params, - uint32_t bidm, uint32_t bidn, - uint32_t warp_m, uint32_t warp_n, - uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) - , inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row) - , bidn_(bidn) // CTA id within the group. - , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row) - , w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row) {} - - template - inline __device__ T allreduce(T data, const Op &op) { - // Trivial case - if (inter_cta_.group_size_ == 1) { - return Base::allreduce(data, op); - } - - data = Base::reduce(data, op); - // We switch workspace every iteration. - T * const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - // Warp leaders 0 hold the CTA-local results. - if ( this->warp_n_ == 0 && this->lane_ == 0 ) { - workspace[bidn_] = data; - } - inter_cta_.sync(); - T total = Zeros::get(); - for ( int it = this->lane_; - it < inter_cta_.group_size_; - it += THREADS_PER_WARP ) { - total = op(total, workspace[it]); - } - total = Reducer::allreduce_(total, op); - - return total; - } - - template - inline __device__ T reduce(T data, const Op &op) { - return allreduce(data, op); - } - - InterCTASync inter_cta_; - - T * const w0_; - T * const w1_; - int bidn_; + using Base = Reducer; + using Type = typename Base::Type; + + template + inline __device__ DynamicReducer(const Params ¶ms, uint32_t bidm, uint32_t bidn, + uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), + inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row), + bidn_(bidn) // CTA id within the group. + , + w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row), + w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row) {} + + template + inline __device__ T allreduce(T data, const Op &op) { + // Trivial case + if (inter_cta_.group_size_ == 1) { + return Base::allreduce(data, op); + } + + data = Base::reduce(data, op); + // We switch workspace every iteration. + T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + // Warp leaders 0 hold the CTA-local results. 
+ if (this->warp_n_ == 0 && this->lane_ == 0) { + workspace[bidn_] = data; + } + inter_cta_.sync(); + T total = Zeros::get(); + for (int it = this->lane_; it < inter_cta_.group_size_; it += THREADS_PER_WARP) { + total = op(total, workspace[it]); + } + total = Reducer::allreduce_(total, op); + + return total; + } + + template + inline __device__ T reduce(T data, const Op &op) { + return allreduce(data, op); + } + + InterCTASync inter_cta_; + + T *const w0_; + T *const w1_; + int bidn_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -625,248 +607,249 @@ A detailed reference on the exact version implemented (with better numerical sta https://dbs.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf */ -template -inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active) { // NOLINT(*) - // Assume at least leftmost is valid and - // init: step = next_pow2(num_active) / 2 (might get NaN otherwise) - int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); - - #pragma unroll - for ( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) { - // Exchange - T n_b = warp_shuffle_down(n_a, step); - T m_b = warp_shuffle_down(m_a, step); - T m2_b = warp_shuffle_down(m2_a, step); - - // Update - const T n_ab = n_a + n_b; // We can handle one of them being 0, not both. - // Might have different n per thread, otherwise this would simplify :( - const T rn_ab = 1.f / n_ab; - const T delta = m_a - m_b; - const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; - const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; - - n_a = n_ab; - m_a = m_ab; - m2_a = m2_ab; - } - // Intra-warp broadcast (only lane 0 has valid stats). - m_a = __shfl_sync(static_cast(-1), m_a, 0); - m2_a = __shfl_sync(static_cast(-1), m2_a, 0); +template +inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, + int num_active) { // NOLINT(*) + // Assume at least leftmost is valid and + // init: step = next_pow2(num_active) / 2 (might get NaN otherwise) + int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); + +#pragma unroll + for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) { + // Exchange + T n_b = warp_shuffle_down(n_a, step); + T m_b = warp_shuffle_down(m_a, step); + T m2_b = warp_shuffle_down(m2_a, step); + + // Update + const T n_ab = n_a + n_b; // We can handle one of them being 0, not both. + // Might have different n per thread, otherwise this would simplify :( + const T rn_ab = 1.f / n_ab; + const T delta = m_a - m_b; + const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; + const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; + + n_a = n_ab; + m_a = m_ab; + m2_a = m2_ab; + } + // Intra-warp broadcast (only lane 0 has valid stats). + m_a = __shfl_sync(static_cast(-1), m_a, 0); + m2_a = __shfl_sync(static_cast(-1), m2_a, 0); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Stats { - // This could be done generically with the Reducer. But then we - // would have to exchange 3 instead of 2 fields. 
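For reference, the pairwise update that warp_chan_upd_dynamic applies at each shuffle step is the Chan-style merge of two partial (count, mean, M2) triples described in the reference linked above:

$$
n_{ab} = n_a + n_b, \qquad
m_{ab} = \frac{n_a m_a + n_b m_b}{n_{ab}}, \qquad
M_{2,ab} = M_{2,a} + M_{2,b} + (m_a - m_b)^2 \, \frac{n_a n_b}{n_{ab}},
$$

with the row variance recovered as $M_2 / n$ once every partial has been folded in; the delta and rn_ab terms in the loop compute exactly these quantities, one shuffle distance at a time.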
- - using BlockStats = Stats; - using stats_t = typename BlockStats::stats_t; - - enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; - - template - inline __device__ Stats(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, - uint32_t warp_n, uint32_t lane, void * smem) - : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW) - , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) - , bidn_(bidn) // CTA id within the group. - , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) - , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) - , warp_n_(warp_n) - , lane_(lane) {} - - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn) { - constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; - // TODO(ptredak) rn is not really needed here.. - constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); - stats_t block_stats = block_stats_.compute(elts, block_rn); - - stats_t * const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - if ( warp_n_ == 0 && lane_ == 0 ) { - workspace[bidn_] = block_stats; - } + // This could be done generically with the Reducer. But then we + // would have to exchange 3 instead of 2 fields. - // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. - inter_cta_.sync(); + using BlockStats = Stats; + using stats_t = typename BlockStats::stats_t; - T n = Zeros::get(); - T m = Zeros::get(); - T m2 = Zeros::get(); + enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; - // Assume CTA group size in N less than 32, such that we can finalize with a single warp. - static_assert(CTAS_PER_ROW <= 32); + template + inline __device__ Stats(const Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t lane, void *smem) + : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW), + block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), + bidn_(bidn) // CTA id within the group. + , + w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), + w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW), + warp_n_(warp_n), + lane_(lane) {} - // Every warp does the final reduction locally. - if ( lane_ < CTAS_PER_ROW ) { - stats_t result = workspace[lane_]; - n = ELTS_PER_ROW_PER_CTA; - m = transformer_engine::Get<0>::of(result); - m2 = transformer_engine::Get<1>::of(result); - } + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; + // TODO(ptredak) rn is not really needed here.. + constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); + stats_t block_stats = block_stats_.compute(elts, block_rn); - warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); + stats_t *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - return { m, m2 }; + if (warp_n_ == 0 && lane_ == 0) { + workspace[bidn_] = block_stats; } - InterCTASync inter_cta_; - BlockStats block_stats_; + // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. + inter_cta_.sync(); + + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Assume CTA group size in N less than 32, such that we can finalize with a single warp. + static_assert(CTAS_PER_ROW <= 32); + + // Every warp does the final reduction locally. 
+ if (lane_ < CTAS_PER_ROW) { + stats_t result = workspace[lane_]; + n = ELTS_PER_ROW_PER_CTA; + m = transformer_engine::Get<0>::of(result); + m2 = transformer_engine::Get<1>::of(result); + } + + warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); + + return {m, m2}; + } - stats_t * const w0_; - stats_t * const w1_; - int bidn_; - int warp_n_; - int lane_; + InterCTASync inter_cta_; + BlockStats block_stats_; + + stats_t *const w0_; + stats_t *const w1_; + int bidn_; + int warp_n_; + int lane_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Stats { - using WarpStats = Stats; - using stats_t = typename WarpStats::stats_t; - - enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; - - template - inline __device__ Stats(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, - uint32_t warp_n, uint32_t lane, void * smem) - : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) - , use0_(true) { - smem0_ = static_cast(smem) + warp_m * WARPS_N; - smem1_ = smem0_ + WARPS_M * WARPS_N; + using WarpStats = Stats; + using stats_t = typename WarpStats::stats_t; + + enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; + + template + inline __device__ Stats(const Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t lane, void *smem) + : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) { + smem0_ = static_cast(smem) + warp_m * WARPS_N; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + stats_t *smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // Compute warp local for all WARPS_N + constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP); + stats_t warp_stats = warp_stats_.compute(elts, warp_rn); + + // Each warp warp leader stores its stats + const auto warp_n = warp_stats_.reducer_.warp_n_; + const auto lane = warp_stats_.reducer_.lane_; + if (lane == 0) { + smem[warp_n] = warp_stats; } + __syncthreads(); - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn) { - stats_t * smem = use0_ ? 
smem0_ : smem1_; - use0_ = !use0_; - // Compute warp local for all WARPS_N - constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP); - stats_t warp_stats = warp_stats_.compute(elts, warp_rn); - - // Each warp warp leader stores its stats - const auto warp_n = warp_stats_.reducer_.warp_n_; - const auto lane = warp_stats_.reducer_.lane_; - if ( lane == 0 ) { - smem[warp_n] = warp_stats; - } - __syncthreads(); - - T n = Zeros::get(); - T m = Zeros::get(); - T m2 = Zeros::get(); - - // Assume that there are less than 32 warps, such that we can finalize with a single warp - static_assert(WARPS_N <= 32); - if (lane < WARPS_N) { - stats_t result = smem[lane]; - n = N * THREADS_PER_WARP; - m = transformer_engine::Get<0>::of(result); - m2 = transformer_engine::Get<1>::of(result); - } - - warp_chan_upd_dynamic(m, m2, n, WARPS_N); + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); - return { m, m2 }; + // Assume that there are less than 32 warps, such that we can finalize with a single warp + static_assert(WARPS_N <= 32); + if (lane < WARPS_N) { + stats_t result = smem[lane]; + n = N * THREADS_PER_WARP; + m = transformer_engine::Get<0>::of(result); + m2 = transformer_engine::Get<1>::of(result); } - WarpStats warp_stats_; - stats_t * smem0_; - stats_t * smem1_; - bool use0_; + + warp_chan_upd_dynamic(m, m2, n, WARPS_N); + + return {m, m2}; + } + WarpStats warp_stats_; + stats_t *smem0_; + stats_t *smem1_; + bool use0_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Stats { - using stats_t = typename TypeToVec2::Type; - // The simple Warp reducer. - using Reducer = Reducer; - - enum { SMEM_BYTES = 0 }; + using stats_t = typename TypeToVec2::Type; + // The simple Warp reducer. 
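A quick host-side check (not TE code) that folding per-chunk partials with this merge reproduces the statistics of the full row, which is what the warp-to-CTA-to-group hierarchy above relies on:

```cuda
#include <cstdio>
#include <vector>

struct Partial { double n, m, m2; };  // count, mean, sum of squared deviations

Partial merge(const Partial &a, const Partial &b) {
  const double n = a.n + b.n;
  const double delta = a.m - b.m;
  const double m = (a.n * a.m + b.n * b.m) / n;
  const double m2 = a.m2 + b.m2 + delta * delta * a.n * b.n / n;
  return {n, m, m2};
}

int main() {
  const std::vector<double> row = {0.5, -1.0, 2.0, 3.5, -0.25, 1.0, 4.0, -2.0};

  // Per-chunk partials (chunks of 2), as each warp would produce locally.
  Partial total = {0.0, 0.0, 0.0};
  for (size_t i = 0; i < row.size(); i += 2) {
    const double m = 0.5 * (row[i] + row[i + 1]);
    const double m2 = (row[i] - m) * (row[i] - m) + (row[i + 1] - m) * (row[i + 1] - m);
    const Partial p = {2.0, m, m2};
    total = (total.n == 0.0) ? p : merge(total, p);
  }
  printf("merged:    mean = %g, variance = %g\n", total.m, total.m2 / total.n);

  // Direct two-pass computation for comparison.
  double mean = 0.0, var = 0.0;
  for (double v : row) mean += v;
  mean /= row.size();
  for (double v : row) var += (v - mean) * (v - mean);
  var /= row.size();
  printf("reference: mean = %g, variance = %g\n", mean, var);
  return 0;
}
```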
+ using Reducer = Reducer; - template - inline __device__ Stats(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, - uint32_t warp_n, uint32_t lane, void * smem) - : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {} + enum { SMEM_BYTES = 0 }; - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn) { - auto sum = Sum(); + template + inline __device__ Stats(const Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t lane, void *smem) + : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {} - T m = Zeros::get(); - #pragma unroll - for ( int it = 0; it < N; it++ ) { - m += elts[it]; - } - m = reducer_.allreduce(m, sum) * rn; + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + auto sum = Sum(); - T m2 = Zeros::get(); - #pragma unroll - for ( int it = 0; it < N; it++ ) { - T diff = (elts[it] - m); - m2 += diff * diff; - } - m2 = reducer_.allreduce(m2, sum); + T m = Zeros::get(); +#pragma unroll + for (int it = 0; it < N; it++) { + m += elts[it]; + } + m = reducer_.allreduce(m, sum) * rn; - return {m, m2}; + T m2 = Zeros::get(); +#pragma unroll + for (int it = 0; it < N; it++) { + T diff = (elts[it] - m); + m2 += diff * diff; } + m2 = reducer_.allreduce(m2, sum); + + return {m, m2}; + } - Reducer reducer_; + Reducer reducer_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template __device__ __forceinline__ float warp_reduce_max(const float m) { - float tmp = m; + float tmp = m; #pragma unroll - for (int delta = num_elems/2; delta > 0; delta /= 2) { - const float other_m = __shfl_down_sync(0xFFFFFFFF, tmp, delta); - __builtin_assume(tmp >= 0); - __builtin_assume(other_m >= 0); - tmp = fmaxf(tmp, other_m); - } - return tmp; + for (int delta = num_elems / 2; delta > 0; delta /= 2) { + const float other_m = __shfl_down_sync(0xFFFFFFFF, tmp, delta); + __builtin_assume(tmp >= 0); + __builtin_assume(other_m >= 0); + tmp = fmaxf(tmp, other_m); + } + return tmp; } template __device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) { - __shared__ float staging[num_warps]; - constexpr int warp_size = 32; - const float my_max = m; - const float my_warp_max = warp_reduce_max(my_max); - if (threadIdx.x % 32 == 0) { - staging[warpid] = my_warp_max; - } - __syncthreads(); - compute_t result = 0; - if (warpid == 0) { - const float my_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0; - result = warp_reduce_max(my_max); - } - return result; + __shared__ float staging[num_warps]; + constexpr int warp_size = 32; + const float my_max = m; + const float my_warp_max = warp_reduce_max(my_max); + if (threadIdx.x % 32 == 0) { + staging[warpid] = my_warp_max; + } + __syncthreads(); + compute_t result = 0; + if (warpid == 0) { + const float my_max = threadIdx.x < num_warps ? 
staging[threadIdx.x] : 0; + result = warp_reduce_max(my_max); + } + return result; } // Works only on positive values -__device__ __forceinline__ void atomicMaxFloat(float * addr, const float value) { - atomicMax(reinterpret_cast(addr), __float_as_int(value)); +__device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) { + atomicMax(reinterpret_cast(addr), __float_as_int(value)); } // Works only on positive values -__device__ __forceinline__ void atomicMinFloat(float * addr, const float value) { - atomicMin(reinterpret_cast(addr), __float_as_int(value)); +__device__ __forceinline__ void atomicMinFloat(float *addr, const float value) { + atomicMin(reinterpret_cast(addr), __float_as_int(value)); } template -__device__ __forceinline__ void reciprocal(T * value_inv, const T value) { - *value_inv = 1 / value; +__device__ __forceinline__ void reciprocal(T *value_inv, const T value) { + *value_inv = 1 / value; } } // namespace transformer_engine diff --git a/transformer_engine/common/utils.py b/transformer_engine/common/utils.py index 339fa59f6caf696ce6c8d46d50f693c56e237c02..6fd9d141b4174c1cf3c7caf7eda34366a2ab6ec8 100644 --- a/transformer_engine/common/utils.py +++ b/transformer_engine/common/utils.py @@ -7,10 +7,11 @@ import warnings from enum import Enum warnings.filterwarnings( - "module", category=DeprecationWarning, module="transformer_engine.common.utils") + "module", category=DeprecationWarning, module="transformer_engine.common.utils" +) -class DeprecatedEnum: # pylint: disable=too-few-public-methods +class DeprecatedEnum: # pylint: disable=too-few-public-methods """DeprecatedEnum""" def __init__(self, enum_cls, msg): @@ -33,7 +34,7 @@ def deprecate_wrapper(obj, msg): if issubclass(obj, Enum): return DeprecatedEnum(obj, msg) - class DeprecatedCls(obj): # pylint: disable=too-few-public-methods + class DeprecatedCls(obj): # pylint: disable=too-few-public-methods """DeprecatedCls""" def __init__(self, *args, **kwargs): @@ -51,4 +52,5 @@ def deprecate_wrapper(obj, msg): return deprecated raise NotImplementedError( - f"deprecate_cls_wrapper only support Class and Function, but got {type(obj)}.") + f"deprecate_cls_wrapper only support Class and Function, but got {type(obj)}." + ) diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 7a2cf2690a41dcbdb247ece0e3b60bbf1bd98735..3200c8a019d0ba5ed1b8dfceabd68fd74ec72e6b 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -34,22 +34,24 @@ from .sharding import MajorShardingType, ShardingResource, ShardingType from ..common.utils import deprecate_wrapper from ..common.utils import DeprecatedEnum -MajorShardingType = DeprecatedEnum(MajorShardingType, - "MajorShardingType is deprecating in the near feature.") +MajorShardingType = DeprecatedEnum( + MajorShardingType, "MajorShardingType is deprecating in the near feature." 
+) ShardingType = DeprecatedEnum(ShardingType, "ShardingType is deprecating in the near feature.") ShardingResource = deprecate_wrapper( ShardingResource, - "ShardingResource is renamed to MeshResource, and will be removed in the near feature.") + "ShardingResource is renamed to MeshResource, and will be removed in the near feature.", +) __all__ = [ - 'NVTE_FP8_COLLECTION_NAME', - 'fp8_autocast', - 'update_collections', - 'get_delayed_scaling', - 'MeshResource', - 'MajorShardingType', - 'ShardingResource', - 'ShardingType', - 'flax', - 'praxis', + "NVTE_FP8_COLLECTION_NAME", + "fp8_autocast", + "update_collections", + "get_delayed_scaling", + "MeshResource", + "MajorShardingType", + "ShardingResource", + "ShardingType", + "flax", + "praxis", ] diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 951d177584bc5b25ecfc60804be4bd23d17c683e..e1ed2305d16e7c74ce4f930266bd849f9db9326c 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -22,6 +22,7 @@ class AttnBiasType(Enum): PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias)) POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias) """ + NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS @@ -34,6 +35,7 @@ class AttnMaskType(Enum): CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs. PADDING_CAUSAL_MASK: A combination of both causal and padding masks. """ + NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK @@ -42,6 +44,7 @@ class AttnMaskType(Enum): class QKVLayout(Enum): """QKV layout""" + BS3HD = NVTE_QKV_Layout.NVTE_BS3HD BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD @@ -54,72 +57,125 @@ def canonicalize_attn_mask_type(attn_mask_type: str): However, we will lease this limitation in the near feature. 
""" match attn_mask_type: - case 'no_mask': + case "no_mask": return AttnMaskType.NO_MASK - case 'padding': + case "padding": return AttnMaskType.PADDING_MASK - case 'causal': + case "causal": return AttnMaskType.CAUSAL_MASK - case 'padding_causal' | 'causal_padding': + case "padding_causal" | "causal_padding": return AttnMaskType.PADDING_CAUSAL_MASK - raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type=" - "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}") - - -def is_fused_attn_kernel_available(q_dtype, kv_dtype, qkv_layout, attn_bias_type, attn_mask_type, - dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, - kv_max_seqlen, head_dim): + raise ValueError( + f"Unsupported {attn_mask_type=}, supported attn_mask_type=" + "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}" + ) + + +def is_fused_attn_kernel_available( + q_dtype, + kv_dtype, + qkv_layout, + attn_bias_type, + attn_mask_type, + dropout_probability, + q_num_heads, + kv_num_heads, + q_max_seqlen, + kv_max_seqlen, + head_dim, +): """ To check whether the fused attention kernel is supported """ - return tex.FusedAttnHelper(q_dtype, kv_dtype, qkv_layout.value, attn_bias_type.value, - attn_mask_type.value, dropout_probability, q_num_heads, kv_num_heads, - q_max_seqlen, kv_max_seqlen, head_dim).is_fused_attn_kernel_available() - - -def fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, - seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, scaling_factor: float, - dropout_probability: float, is_training: bool): + return tex.FusedAttnHelper( + q_dtype, + kv_dtype, + qkv_layout.value, + attn_bias_type.value, + attn_mask_type.value, + dropout_probability, + q_num_heads, + kv_num_heads, + q_max_seqlen, + kv_max_seqlen, + head_dim, + ).is_fused_attn_kernel_available() + + +def fused_attn_qkvpacked( + qkv: jnp.ndarray, + bias: jnp.ndarray | None, + mask: jnp.ndarray, + seed: jnp.ndarray | None, + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): """ Fused attention with the qkvpacked inputs """ - output = _fused_attn_qkvpacked(qkv, - bias, - mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + output = _fused_attn_qkvpacked( + qkv, + bias, + mask, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + ) return output @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8)) -def _fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, - seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, scaling_factor: float, - dropout_probability: float, is_training: bool): - - output, _ = _fused_attn_fwd_qkvpacked_rule(qkv, bias, mask, seed, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, - is_training) +def _fused_attn_qkvpacked( + qkv: jnp.ndarray, + bias: jnp.ndarray | None, + mask: jnp.ndarray, + seed: jnp.ndarray | None, + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): + + output, _ = _fused_attn_fwd_qkvpacked_rule( + qkv, + bias, + mask, + seed, + attn_bias_type, + attn_mask_type, + 
scaling_factor, + dropout_probability, + is_training, + ) return output -def _fused_attn_fwd_qkvpacked_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, - seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, scaling_factor: float, - dropout_probability: float, is_training: bool): +def _fused_attn_fwd_qkvpacked_rule( + qkv: jnp.ndarray, + bias: jnp.ndarray | None, + mask: jnp.ndarray, + seed: jnp.ndarray | None, + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: batch, seqlen, *_ = qkv.shape actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32) else: assert mask is not None mask = jnp.logical_not(mask) - actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) + actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) output, softmax_aux, rng_state = tex.fused_attn_fwd_qkvpacked( qkv, bias, @@ -129,29 +185,33 @@ def _fused_attn_fwd_qkvpacked_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, m attn_mask_type=attn_mask_type.value, scaling_factor=scaling_factor, dropout_probability=dropout_probability, - is_training=is_training) - output = checkpoint_name(output, 'context') - softmax_aux = checkpoint_name(softmax_aux, 'context') - rng_state = checkpoint_name(rng_state, 'context') + is_training=is_training, + ) + output = checkpoint_name(output, "context") + softmax_aux = checkpoint_name(softmax_aux, "context") + rng_state = checkpoint_name(rng_state, "context") return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen) -def _fused_attn_bwd_qkvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training, ctx, dz): +def _fused_attn_bwd_qkvpacked_rule( + attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, ctx, dz +): qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx - grad_qkv, grad_bias = tex.fused_attn_bwd_qkvpacked(qkv, - bias, - softmax_aux, - rng_state, - output, - dz, - actual_seqlen, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + grad_qkv, grad_bias = tex.fused_attn_bwd_qkvpacked( + qkv, + bias, + softmax_aux, + rng_state, + output, + dz, + actual_seqlen, + attn_bias_type=attn_bias_type.value, + attn_mask_type=attn_mask_type.value, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None @@ -162,42 +222,79 @@ def _fused_attn_bwd_qkvpacked_rule(attn_bias_type, attn_mask_type, scaling_facto _fused_attn_qkvpacked.defvjp(_fused_attn_fwd_qkvpacked_rule, _fused_attn_bwd_qkvpacked_rule) -def fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, - seed: jnp.ndarray, attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, scaling_factor: float, - dropout_probability: float, is_training: bool): +def fused_attn_kvpacked( + q: jnp.ndarray, + kv: jnp.ndarray, + bias: jnp.ndarray, + mask: jnp.ndarray, + seed: jnp.ndarray, + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): """ Fused attention with the kvpacked inputs """ - output = _fused_attn_kvpacked(q, - kv, - bias, - 
mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + output = _fused_attn_kvpacked( + q, + kv, + bias, + mask, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + ) return output @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) -def _fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, - seed: jnp.ndarray, attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, scaling_factor: float, - dropout_probability: float, is_training: bool): - - output, _ = _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, - is_training) +def _fused_attn_kvpacked( + q: jnp.ndarray, + kv: jnp.ndarray, + bias: jnp.ndarray, + mask: jnp.ndarray, + seed: jnp.ndarray, + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): + + output, _ = _fused_attn_fwd_kvpacked_rule( + q, + kv, + bias, + mask, + seed, + attn_bias_type, + attn_mask_type, + scaling_factor, + dropout_probability, + is_training, + ) return output -def _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training): +def _fused_attn_fwd_kvpacked_rule( + q, + kv, + bias, + mask, + seed, + attn_bias_type, + attn_mask_type, + scaling_factor, + dropout_probability, + is_training, +): if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: batch, s_q, *_ = q.shape s_kv = kv.shape[1] @@ -206,9 +303,9 @@ def _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, attn_ else: assert mask is not None mask = jnp.logical_not(mask) - q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) + q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) if attn_mask_type == AttnMaskType.PADDING_MASK: - kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,) + kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,) else: # When mask is causal, the actual seqlen is not the last row, use max to find it kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2)) @@ -224,31 +321,35 @@ def _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, attn_ attn_mask_type=attn_mask_type.value, scaling_factor=scaling_factor, dropout_probability=dropout_probability, - is_training=is_training) - output = checkpoint_name(output, 'context') - softmax_aux = checkpoint_name(softmax_aux, 'context') - rng_state = checkpoint_name(rng_state, 'context') + is_training=is_training, + ) + output = checkpoint_name(output, "context") + softmax_aux = checkpoint_name(softmax_aux, "context") + rng_state = checkpoint_name(rng_state, "context") return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen) -def _fused_attn_bwd_kvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training, ctx, dz): +def _fused_attn_bwd_kvpacked_rule( + attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, ctx, dz +): q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx - grad_q, 
grad_kv, grad_bias = tex.fused_attn_bwd_kvpacked(q, - kv, - bias, - softmax_aux, - rng_state, - output, - dz, - q_actual_seqlen, - kv_actual_seqlen, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + grad_q, grad_kv, grad_bias = tex.fused_attn_bwd_kvpacked( + q, + kv, + bias, + softmax_aux, + rng_state, + output, + dz, + q_actual_seqlen, + kv_actual_seqlen, + attn_bias_type=attn_bias_type.value, + attn_mask_type=attn_mask_type.value, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None @@ -259,41 +360,84 @@ def _fused_attn_bwd_kvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor _fused_attn_kvpacked.defvjp(_fused_attn_fwd_kvpacked_rule, _fused_attn_bwd_kvpacked_rule) -def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, - seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, - scaling_factor: float, dropout_probability: float, is_training: bool): +def fused_attn( + q: jnp.ndarray, + k: jnp.ndarray, + v: jnp.ndarray, + bias: jnp.ndarray, + mask: jnp.ndarray, + seed: jnp.ndarray, + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): """ Dot product attention with the seperated query, key, value """ - output = _fused_attn(q, - k, - v, - bias, - mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + output = _fused_attn( + q, + k, + v, + bias, + mask, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + ) return output @partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10)) -def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, - mask: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float, - is_training: bool): - - output, _ = _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training) +def _fused_attn( + q: jnp.ndarray, + k: jnp.ndarray, + v: jnp.ndarray, + bias: jnp.ndarray, + mask: jnp.ndarray, + seed: jnp.ndarray, + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): + + output, _ = _fused_attn_fwd_rule( + q, + k, + v, + bias, + mask, + seed, + attn_bias_type, + attn_mask_type, + scaling_factor, + dropout_probability, + is_training, + ) return output -def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): +def _fused_attn_fwd_rule( + q, + k, + v, + bias, + mask, + seed, + attn_bias_type, + attn_mask_type, + scaling_factor, + dropout_probability, + is_training, +): if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: batch, s_q, *_ = q.shape s_kv = k.shape[1] @@ -302,51 +446,65 @@ def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_ty else: assert mask is not None mask = jnp.logical_not(mask) - q_actual_seqlen = jnp.sum(mask, 
axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) + q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) if attn_mask_type == AttnMaskType.PADDING_MASK: - kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,) + kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,) else: # When mask is causal, the actual seqlen is not the last row, use max to find it kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2)) - output, softmax_aux, rng_state = tex.fused_attn_fwd(q, - k, - v, - bias, - q_actual_seqlen, - kv_actual_seqlen, - seed, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - output = checkpoint_name(output, 'context') - softmax_aux = checkpoint_name(softmax_aux, 'context') - rng_state = checkpoint_name(rng_state, 'context') - return output, (q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, - kv_actual_seqlen) - - -def _fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, - is_training, ctx, dz): + output, softmax_aux, rng_state = tex.fused_attn_fwd( + q, + k, + v, + bias, + q_actual_seqlen, + kv_actual_seqlen, + seed, + attn_bias_type=attn_bias_type.value, + attn_mask_type=attn_mask_type.value, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + ) + output = checkpoint_name(output, "context") + softmax_aux = checkpoint_name(softmax_aux, "context") + rng_state = checkpoint_name(rng_state, "context") + return output, ( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + q_actual_seqlen, + kv_actual_seqlen, + ) + + +def _fused_attn_bwd_rule( + attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, ctx, dz +): q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx - grad_q, grad_k, grad_v, grad_bias = tex.fused_attn_bwd(q, - k, - v, - bias, - softmax_aux, - rng_state, - output, - dz, - q_actual_seqlen, - kv_actual_seqlen, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + grad_q, grad_k, grad_v, grad_bias = tex.fused_attn_bwd( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + dz, + q_actual_seqlen, + kv_actual_seqlen, + attn_bias_type=attn_bias_type.value, + attn_mask_type=attn_mask_type.value, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index b8f74fd7c95a2c871c9b067a7c3809874b0304fc..f263f09b9ceb61a1b1f5b7be3db77c68c7deaa0e 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -20,25 +20,25 @@ from .misc import ( check_valid_batch_dims, jax_dtype_to_te_dtype, jax_dtype_to_ir_dtype, - get_padded_spec + get_padded_spec, ) from ..sharding import all_reduce_max_along_all_axes_except_PP -__all__ = ['act_lu', 'dact_lu', 'act_lu_fp8'] +__all__ = ["act_lu", "dact_lu", "act_lu_fp8"] ActivationEnum = { - ('gelu',): NVTE_Activation_Type.GELU, - ('gelu', 'linear'): NVTE_Activation_Type.GEGLU, - ('silu',): NVTE_Activation_Type.SILU, - 
('silu', 'linear'): NVTE_Activation_Type.SWIGLU, - ('relu',): NVTE_Activation_Type.RELU, - ('relu', 'linear'): NVTE_Activation_Type.REGLU, - ('quick_gelu',): NVTE_Activation_Type.QGELU, - ('quick_gelu', 'linear'): NVTE_Activation_Type.QGEGLU, - ('squared_relu',): NVTE_Activation_Type.SRELU, - ('squared_relu', 'linear'): NVTE_Activation_Type.SREGLU, + ("gelu",): NVTE_Activation_Type.GELU, + ("gelu", "linear"): NVTE_Activation_Type.GEGLU, + ("silu",): NVTE_Activation_Type.SILU, + ("silu", "linear"): NVTE_Activation_Type.SWIGLU, + ("relu",): NVTE_Activation_Type.RELU, + ("relu", "linear"): NVTE_Activation_Type.REGLU, + ("quick_gelu",): NVTE_Activation_Type.QGELU, + ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU, + ("squared_relu",): NVTE_Activation_Type.SRELU, + ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU, } @@ -46,6 +46,7 @@ class ActLuPrimitive(BasePrimitive): """ Activation Forward Primitive """ + name = "te_act_lu" multiple_results = False inner_primitive = None @@ -61,7 +62,7 @@ class ActLuPrimitive(BasePrimitive): assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] x_shape = x_aval.shape - assert (x_shape[-2] == 2 or x_shape[-2] == 1) + assert x_shape[-2] == 2 or x_shape[-2] == 1 hidden_size = x_shape[-1] batch_shapes = x_shape[:-2] out_aval = core.raise_to_shaped(x_aval) @@ -92,7 +93,8 @@ class ActLuPrimitive(BasePrimitive): batch_size = reduce(operator.mul, ir_x_shape[:-2]) in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) opaque = transformer_engine_jax.pack_common_descriptor( - (batch_size, hidden_size), in_dtype, in_dtype, act_enum) + (batch_size, hidden_size), in_dtype, in_dtype, act_enum + ) out = custom_caller(ActLuPrimitive.name, args, opaque, False) @@ -111,8 +113,8 @@ class ActLuPrimitive(BasePrimitive): """ check_valid_batch_dims(batch_dims) assert ActLuPrimitive.outer_primitive is not None - inputs, = batched_args - inputs_bdim, = batch_dims + (inputs,) = batched_args + (inputs_bdim,) = batch_dims out_bdims = inputs_bdim return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_enum), out_bdims @@ -122,7 +124,7 @@ class ActLuPrimitive(BasePrimitive): """ act_lu infer_sharding_from_operands """ - del result_infos, act_enum # Unused. + del result_infos, act_enum # Unused. 
x_spec = get_padded_spec(arg_infos[0]) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) return out_sharding @@ -161,6 +163,7 @@ class DActLuPrimitive(BasePrimitive): """ Dgated ActLu Primitive """ + name = "te_dact_lu" multiple_results = False inner_primitive = None @@ -177,7 +180,7 @@ class DActLuPrimitive(BasePrimitive): assert x_aval.dtype == dtype for axis in range(len(dz_aval.shape) - 1): assert dz_aval.shape[axis] == x_aval.shape[axis] - assert (x_aval.shape[-2] == 2 or x_aval.shape[-2] == 1) + assert x_aval.shape[-2] == 2 or x_aval.shape[-2] == 1 i_hidden_size = dz_aval.shape[-1] g_hidden_size = x_aval.shape[-1] @@ -198,7 +201,7 @@ class DActLuPrimitive(BasePrimitive): ir_in_shape = ir_in_type.shape gi_type = ir.RankedTensorType(x.type) gi_shape = gi_type.shape -# assert ir_in_shape == gi_shape + # assert ir_in_shape == gi_shape for axis in range(len(ir_in_shape) - 1): assert ir_in_shape[axis] == gi_shape[axis] @@ -217,8 +220,9 @@ class DActLuPrimitive(BasePrimitive): args = CustomCallArgsWrapper(out_types, operands, operand_shapes) in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size), - in_dtype, in_dtype, act_enum) + opaque = transformer_engine_jax.pack_common_descriptor( + (ir_batch_size, i_hidden_size), in_dtype, in_dtype, act_enum + ) out = custom_caller(DActLuPrimitive.name, args, opaque, False) @@ -251,7 +255,7 @@ class DActLuPrimitive(BasePrimitive): """ dact_lu infer_sharding_from_operands """ - del result_infos, act_enum # Unused. + del result_infos, act_enum # Unused. act_lu_out_spec = get_padded_spec(arg_infos[1]) dx_sharding = NamedSharding(mesh, PartitionSpec(*act_lu_out_spec)) return dx_sharding @@ -275,8 +279,9 @@ class DActLuPrimitive(BasePrimitive): register_primitive(DActLuPrimitive) -def dact_lu(inputs: jnp.ndarray, act_lu_inputs: jnp.ndarray, - activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray: +def dact_lu( + inputs: jnp.ndarray, act_lu_inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]] +) -> jnp.ndarray: """ dact_lu fusion wrapper Return dgated_act_lu(inputs) @@ -289,15 +294,17 @@ class ActLuFp8Primitive(BasePrimitive): """ ActLu FP8 Primitive """ + name = "te_act_lu_fp8" multiple_results = True - impl_static_args = (4, 5) #out_dtype, act_enum + impl_static_args = (4, 5) # out_dtype, act_enum inner_primitive = None outer_primitive = None @staticmethod - def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, - act_enum): # pylint: disable=unused-argument + def abstract( + x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, act_enum + ): # pylint: disable=unused-argument """ te_act_lu_p abstract """ @@ -309,7 +316,7 @@ class ActLuFp8Primitive(BasePrimitive): assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - assert (x_aval.shape[-2] == 1 or x_aval.shape[-2] == 2) + assert x_aval.shape[-2] == 1 or x_aval.shape[-2] == 2 hidden_size = x_aval.shape[-1] batch_shape = x_aval.shape[:-2] out_shape = (batch_shape) + (hidden_size,) @@ -349,17 +356,16 @@ class ActLuFp8Primitive(BasePrimitive): operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_common_descriptor(( - batch_size, hidden_size), + opaque = transformer_engine_jax.pack_common_descriptor( + (batch_size, hidden_size), jax_dtype_to_te_dtype(x_aval.dtype), 
jax_dtype_to_te_dtype(out_dtype), - act_enum) + act_enum, + ) - out = custom_caller(ActLuFp8Primitive.name, - args, - opaque, - False, - operand_output_aliases={1: 1}) + out = custom_caller( + ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1} + ) return out @@ -369,12 +375,9 @@ class ActLuFp8Primitive(BasePrimitive): to describe implementation """ assert ActLuFp8Primitive.inner_primitive is not None - out, updated_amax = ActLuFp8Primitive.inner_primitive.bind(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - act_enum=act_enum) + out, updated_amax = ActLuFp8Primitive.inner_primitive.bind( + x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum + ) return out, updated_amax @staticmethod @@ -388,9 +391,12 @@ class ActLuFp8Primitive(BasePrimitive): x_bdim, amax_bdim, _, _ = batch_dims out_bdims = x_bdim, amax_bdim - return ActLuFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, - out_dtype=out_dtype, - act_enum=act_enum), out_bdims + return ( + ActLuFp8Primitive.outer_primitive.bind( + x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum + ), + out_bdims, + ) @staticmethod def infer_sharding_from_operands(out_dtype, act_enum, mesh, arg_infos, result_infos): @@ -410,12 +416,9 @@ class ActLuFp8Primitive(BasePrimitive): out_shardings = (out_sharding, amax_sharding) def sharded_impl(x, amax, scale, scale_inv): - local_x, local_amax = ActLuFp8Primitive.impl(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - act_enum=act_enum) + local_x, local_amax = ActLuFp8Primitive.impl( + x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum + ) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) return local_x, global_updated_amax @@ -426,9 +429,14 @@ class ActLuFp8Primitive(BasePrimitive): register_primitive(ActLuFp8Primitive) -def act_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, - out_dtype: jnp.dtype, activation_type: Sequence[Union[str, Callable]] - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: +def act_lu_fp8( + x: jnp.ndarray, + amax: jnp.ndarray, + scale: jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: jnp.dtype, + activation_type: Sequence[Union[str, Callable]], +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ act wrapper Return FP8(act_lu(x)) @@ -436,5 +444,6 @@ def act_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: (N, 2, H) for gated activations """ act_type_id = ActivationEnum[activation_type] - return ActLuFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype, - act_enum = act_type_id) + return ActLuFp8Primitive.outer_primitive.bind( + x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id + ) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 0e16447609b0afc51663783aec7f7d56676f4f8e..01aa1141e10acd501e6a920b5d37ce53d8a0e15d 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -18,7 +18,7 @@ from transformer_engine.transformer_engine_jax import ( NVTE_Bias_Type, NVTE_Mask_Type, NVTE_QKV_Layout, - NVTE_Fused_Attn_Backend + NVTE_Fused_Attn_Backend, ) from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper @@ -26,7 +26,7 @@ from .misc import ( check_valid_batch_dims, jax_dtype_to_te_dtype, te_dtype_to_jax_dtype, - get_padded_spec + get_padded_spec, ) from ..sharding import ( 
all_reduce_sum_along_dp_fsdp, @@ -35,14 +35,15 @@ from ..sharding import ( ) -__all__ = ['FusedAttnHelper', - 'fused_attn_fwd_qkvpacked', - 'fused_attn_bwd_qkvpacked', - 'fused_attn_fwd_kvpacked', - 'fused_attn_bwd_kvpacked', - 'fused_attn_fwd', - 'fused_attn_bwd', - ] +__all__ = [ + "FusedAttnHelper", + "fused_attn_fwd_qkvpacked", + "fused_attn_bwd_qkvpacked", + "fused_attn_fwd_kvpacked", + "fused_attn_bwd_kvpacked", + "fused_attn_fwd", + "fused_attn_bwd", +] @dataclass(frozen=True) @@ -70,10 +71,18 @@ class FusedAttnHelper: def get_fused_attn_backend(self): """Get the fused attention kernel backend""" return transformer_engine_jax.get_fused_attn_backend( - jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.kv_dtype), - self.qkv_layout, self.attn_bias_type, self.attn_mask_type, self.dropout_probability, - self.q_num_heads, self.kv_num_heads, self.q_max_seqlen, self.kv_max_seqlen, - self.head_dim) + jax_dtype_to_te_dtype(self.q_dtype), + jax_dtype_to_te_dtype(self.kv_dtype), + self.qkv_layout, + self.attn_bias_type, + self.attn_mask_type, + self.dropout_probability, + self.q_num_heads, + self.kv_num_heads, + self.q_max_seqlen, + self.kv_max_seqlen, + self.head_dim, + ) @staticmethod def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): @@ -112,6 +121,7 @@ class _FusedAttnRNGStateChecker: so we have to emulate seed as two 32 bits array. The offset calculation is maintained in the backend. """ + rng_state_dtype: jnp.dtype = jnp.uint32 # (seed,) with internal dtype int64 seed_size: int = 2 @@ -133,7 +143,8 @@ class _FusedAttnRNGStateChecker: warnings.warn( f"Requested {seed.dtype=} is not available, and will be " f"casted to dtype {self.rng_state_dtype}. " - f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.") + "Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning." 
+ ) seed = seed.astype(self.rng_state_dtype) assert seed.dtype == self.rng_state_dtype @@ -156,6 +167,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): """ Fused Attention Forward Primitive """ + name = "te_fused_attn_forward" multiple_results = True impl_static_args = (7, 8, 9, 10, 11, 12) @@ -163,9 +175,22 @@ class FusedAttnFwdPrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(q_aval, k_aval, v_aval, bias_aval, q_seqlen_or_cu_seqlen_aval, - kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type, - qkv_layout, scaling_factor, dropout_probability, is_training): + def abstract( + q_aval, + k_aval, + v_aval, + bias_aval, + q_seqlen_or_cu_seqlen_aval, + kv_seqlen_or_cu_seqlen_aval, + seed_aval, + *, + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + ): """ Fused attention fwd abstract """ @@ -176,16 +201,27 @@ class FusedAttnFwdPrimitive(BasePrimitive): assert q_dtype == k_dtype == v_dtype == bias_dtype assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype - batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \ + batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + ) output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim) out_aval = q_aval.update(shape=output_shape, dtype=q_dtype) # backend determines the softmax buffer shape/dtype - backend = FusedAttnHelper(q_dtype, k_dtype, qkv_layout, attn_bias_type, attn_mask_type, - dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, head_dim).get_fused_attn_backend() + backend = FusedAttnHelper( + q_dtype, + k_dtype, + qkv_layout, + attn_bias_type, + attn_mask_type, + dropout_probability, + attn_heads, + num_gqa_groups, + q_max_seqlen, + kv_max_seqlen, + head_dim, + ).get_fused_attn_backend() if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) @@ -194,7 +230,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: - raise ValueError(f'Unsupported {backend=}') + raise ValueError(f"Unsupported {backend=}") softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with @@ -215,11 +251,25 @@ class FusedAttnFwdPrimitive(BasePrimitive): # prepare for the active fused-attn backend input_batch = reduce(operator.mul, batch_shape) wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, - bias_heads, head_dim, scaling_factor, dropout_probability, attn_bias_type, - attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training) - wkspace_aval = q_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + input_batch, + bias_batch, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + bias_heads, + head_dim, + scaling_factor, + dropout_probability, + attn_bias_type, + attn_mask_type, + qkv_layout, + jax_dtype_to_te_dtype(q_aval.dtype), + is_training, + ) + wkspace_aval = q_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval @@ -228,13 +278,29 
@@ class FusedAttnFwdPrimitive(BasePrimitive): """ Fused attention fwd outer primitive abstract """ - out_aval, softmax_aux_aval, rng_state_aval, _ = \ - FusedAttnFwdPrimitive.abstract(*args, **kwargs) + out_aval, softmax_aux_aval, rng_state_aval, _ = FusedAttnFwdPrimitive.abstract( + *args, **kwargs + ) return out_aval, softmax_aux_aval, rng_state_aval @staticmethod - def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, - attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training): + def lowering( + ctx, + q, + k, + v, + bias, + q_cu_seqlen, + kv_cu_seqlen, + seed, + *, + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + ): """ Fused attention fwd lowering rules """ @@ -248,8 +314,9 @@ class FusedAttnFwdPrimitive(BasePrimitive): q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in - batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \ + batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + ) input_batch = reduce(operator.mul, batch_shape) @@ -262,18 +329,45 @@ class FusedAttnFwdPrimitive(BasePrimitive): wkspace_aval = ctx.avals_out[-1] opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, - bias_heads, head_dim, wkspace_aval.size, scaling_factor, dropout_probability, - attn_bias_type, attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training) + input_batch, + bias_batch, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + bias_heads, + head_dim, + wkspace_aval.size, + scaling_factor, + dropout_probability, + attn_bias_type, + attn_mask_type, + qkv_layout, + jax_dtype_to_te_dtype(q_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + is_training, + ) out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) return out @staticmethod - def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, qkv_layout, - scaling_factor, dropout_probability, is_training): + def impl( + q, + k, + v, + bias, + q_seqlen, + kv_seqlen, + seed, + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + ): assert FusedAttnFwdPrimitive.inner_primitive is not None q_cu_seqlen = generate_cu_seqlen(q_seqlen) @@ -292,29 +386,52 @@ class FusedAttnFwdPrimitive(BasePrimitive): qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, - is_training=is_training) + is_training=is_training, + ) return output, softmax_aux, rng_state @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, qkv_layout, - scaling_factor, dropout_probability, is_training): + def batcher( + batched_args, + batch_dims, + *, + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + ): check_valid_batch_dims(batch_dims) assert FusedAttnFwdPrimitive.outer_primitive is not None q_bdim, *_, seed_bdim = batch_dims out_bdims = q_bdim, q_bdim, seed_bdim - return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims + return ( + 
FusedAttnFwdPrimitive.outer_primitive.bind( + *batched_args, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + ), + out_bdims, + ) @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, - dropout_probability, is_training, mesh, arg_infos, - result_infos): + def infer_sharding_from_operands( + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + mesh, + arg_infos, + result_infos, + ): del attn_bias_type, attn_mask_type, scaling_factor del dropout_probability, is_training, result_infos q_spec = get_padded_spec(arg_infos[0]) @@ -324,40 +441,55 @@ class FusedAttnFwdPrimitive(BasePrimitive): # q_spec = (...batch, q_seqlen, head, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)) + mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None) + ) case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: # q_spec = (...batch, q_seqlen, head, hidden) # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4])) + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4]) + ) case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: # q_spec = (...batch, q_seqlen, head, hidden) # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])) + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]) + ) case _: raise ValueError(f"Unsupported {qkv_layout=}") rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) return (out_sharding, softmax_aux_sharding, rng_state_sharding) @staticmethod - def partition(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, - is_training, mesh, arg_infos, result_infos): + def partition( + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + mesh, + arg_infos, + result_infos, + ): out_sharding = result_infos[0].sharding softmax_aux_sharding = result_infos[1].sharding - rng_state_sharding = seed_sharding = NamedSharding(mesh, - PartitionSpec(get_all_mesh_axes(), None)) + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - impl = partial(FusedAttnFwdPrimitive.impl, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + impl = partial( + FusedAttnFwdPrimitive.impl, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + ) return mesh, impl, out_shardings, arg_shardings @@ -368,6 +500,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): """ Fused Attention Backward Primitive """ + 
name = "te_fused_attn_backward" multiple_results = True impl_static_args = (10, 11, 12, 13, 14, 15) @@ -375,9 +508,25 @@ class FusedAttnBwdPrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, - doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type, - attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training): + def abstract( + q_aval, + k_aval, + v_aval, + bias_aval, + softmax_aux_aval, + rng_state_aval, + output_aval, + doutput_aval, + q_cu_seqlen_aval, + kv_cu_seqlen_aval, + *, + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + ): """ Fused attention bwd abstract """ @@ -391,8 +540,9 @@ class FusedAttnBwdPrimitive(BasePrimitive): assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype - batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \ + batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + ) if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 @@ -401,18 +551,31 @@ class FusedAttnBwdPrimitive(BasePrimitive): bias_batch = reduce(operator.mul, bias_batch_shape) input_batch = reduce(operator.mul, batch_shape) - wkspace_shape, wkspace_dtype = \ - transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, - bias_heads, head_dim, scaling_factor, dropout_probability, attn_bias_type, - attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training) + wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( + input_batch, + bias_batch, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + bias_heads, + head_dim, + scaling_factor, + dropout_probability, + attn_bias_type, + attn_mask_type, + qkv_layout, + jax_dtype_to_te_dtype(q_aval.dtype), + is_training, + ) dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype) dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype) dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) - wkspace_aval = q_aval.update(shape=wkspace_shape, - dtype=te_dtype_to_jax_dtype(wkspace_dtype)) + wkspace_aval = q_aval.update( + shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype) + ) return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval @@ -421,19 +584,44 @@ class FusedAttnBwdPrimitive(BasePrimitive): """ Fused attention fwd outer primitive abstract """ - dq_aval, dk_aval, dv_aval, dbias_aval, _ = \ - FusedAttnBwdPrimitive.abstract(*args, **kwargs) + dq_aval, dk_aval, dv_aval, dbias_aval, _ = FusedAttnBwdPrimitive.abstract(*args, **kwargs) return dq_aval, dk_aval, dv_aval, dbias_aval @staticmethod - def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, - kv_cu_seqlen, *, attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, - dropout_probability, is_training): + def lowering( + ctx, + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + *, + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + ): """ Fused attention bwd lowering rules """ operands = [ - q, k, v, bias, softmax_aux, 
rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, ] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ @@ -445,8 +633,9 @@ class FusedAttnBwdPrimitive(BasePrimitive): q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in - batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \ + batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + ) input_batch = reduce(operator.mul, batch_shape) @@ -459,19 +648,48 @@ class FusedAttnBwdPrimitive(BasePrimitive): wkspace_aval = ctx.avals_out[-1] opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, - bias_heads, head_dim, wkspace_aval.size, scaling_factor, dropout_probability, - attn_bias_type, attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training) + input_batch, + bias_batch, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + bias_heads, + head_dim, + wkspace_aval.size, + scaling_factor, + dropout_probability, + attn_bias_type, + attn_mask_type, + qkv_layout, + jax_dtype_to_te_dtype(q_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + is_training, + ) out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) return out @staticmethod - def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen, - attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, - is_training): + def impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + ): assert FusedAttnBwdPrimitive.inner_primitive is not None q_cu_seqlen = generate_cu_seqlen(q_seqlen) @@ -493,29 +711,52 @@ class FusedAttnBwdPrimitive(BasePrimitive): qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, - is_training=is_training) + is_training=is_training, + ) return dq, dk, dv, dbias @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, qkv_layout, - scaling_factor, dropout_probability, is_training): + def batcher( + batched_args, + batch_dims, + *, + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + ): check_valid_batch_dims(batch_dims) assert FusedAttnBwdPrimitive.outer_primitive is not None q_bdim, k_bdim, v_bdim, *_ = batch_dims out_bdims = q_bdim, k_bdim, v_bdim, q_bdim - return FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims + return ( + FusedAttnBwdPrimitive.outer_primitive.bind( + *batched_args, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + ), + out_bdims, + ) @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, - dropout_probability, is_training, mesh, arg_infos, - result_infos): + def 
infer_sharding_from_operands( + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + mesh, + arg_infos, + result_infos, + ): del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor del dropout_probability, is_training, result_infos q_spec = get_padded_spec(arg_infos[0]) @@ -529,8 +770,17 @@ class FusedAttnBwdPrimitive(BasePrimitive): return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) @staticmethod - def partition(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, - is_training, mesh, arg_infos, result_infos): + def partition( + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + mesh, + arg_infos, + result_infos, + ): del result_infos q_spec = get_padded_spec(arg_infos[0]) k_spec = get_padded_spec(arg_infos[1]) @@ -543,8 +793,9 @@ class FusedAttnBwdPrimitive(BasePrimitive): arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) - def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, - kv_cu_seqlen): + def sharded_impl( + q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen + ): local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( q, k, @@ -561,7 +812,8 @@ class FusedAttnBwdPrimitive(BasePrimitive): qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, - is_training=is_training) + is_training=is_training, + ) global_dbias = local_dbias if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) @@ -573,10 +825,17 @@ class FusedAttnBwdPrimitive(BasePrimitive): register_primitive(FusedAttnBwdPrimitive) -def fused_attn_fwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray, - seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, scaling_factor: float, - dropout_probability: float, is_training: bool): +def fused_attn_fwd_qkvpacked( + qkv: jnp.ndarray, + bias: jnp.ndarray, + seqlen: jnp.ndarray, + seed: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): """ Wrapper for TE self fused attention fwd Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 @@ -589,26 +848,37 @@ def fused_attn_fwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.nd bias = jnp.zeros(0, dtype=qkv.dtype) _not_used = jnp.zeros(0, qkv.dtype) - return FusedAttnFwdPrimitive.outer_primitive.bind(qkv, - _not_used, - _not_used, - bias, - seqlen, - seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - - -def fused_attn_bwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray, - rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray, - seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, scaling_factor: float, - dropout_probability: float, is_training: bool): + return FusedAttnFwdPrimitive.outer_primitive.bind( + qkv, + _not_used, + _not_used, + bias, + seqlen, + seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD, + 
scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + ) + + +def fused_attn_bwd_qkvpacked( + qkv: jnp.ndarray, + bias: jnp.ndarray, + softmax_aux: jnp.ndarray, + rng_state: jnp.ndarray, + output: jnp.ndarray, + doutput: jnp.ndarray, + seqlen: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): """ Wrapper for TE self fused attention bwd Return the gradients of self fused attention with packed qkv input @@ -633,14 +903,24 @@ def fused_attn_bwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: j qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD, scaling_factor=scaling_factor, dropout_probability=dropout_probability, - is_training=is_training) + is_training=is_training, + ) return dqkv, dbias -def fused_attn_fwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, - q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray, - attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, - scaling_factor: float, dropout_probability: float, is_training: bool): +def fused_attn_fwd_kvpacked( + q: jnp.ndarray, + kv: jnp.ndarray, + bias: jnp.ndarray, + q_seqlen: jnp.ndarray, + kv_seqlen: jnp.ndarray, + seed: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): """ Wrapper for TE fused attention fwd with kvpacked inputs Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 @@ -652,26 +932,39 @@ def fused_attn_fwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, assert bias is None bias = jnp.zeros(0, dtype=q.dtype) - return FusedAttnFwdPrimitive.outer_primitive.bind(q, - kv, - jnp.zeros(0, q.dtype), - bias, - q_seqlen, - kv_seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - - -def fused_attn_bwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, - softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray, - doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, - attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, - scaling_factor: float, dropout_probability: float, is_training: bool): + return FusedAttnFwdPrimitive.outer_primitive.bind( + q, + kv, + jnp.zeros(0, q.dtype), + bias, + q_seqlen, + kv_seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + ) + + +def fused_attn_bwd_kvpacked( + q: jnp.ndarray, + kv: jnp.ndarray, + bias: jnp.ndarray, + softmax_aux: jnp.ndarray, + rng_state: jnp.ndarray, + output: jnp.ndarray, + doutput: jnp.ndarray, + q_seqlen: jnp.ndarray, + kv_seqlen: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): """ Wrapper for TE fused attention bwd with kvpacked inputs Return the gradients of fused attention with packed kv input @@ -696,14 +989,25 @@ def fused_attn_bwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, scaling_factor=scaling_factor, dropout_probability=dropout_probability, - 
is_training=is_training) + is_training=is_training, + ) return dq, dkv, dbias -def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, - q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray, - attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, - scaling_factor: float, dropout_probability: float, is_training: bool): +def fused_attn_fwd( + q: jnp.ndarray, + k: jnp.ndarray, + v: jnp.ndarray, + bias: jnp.ndarray, + q_seqlen: jnp.ndarray, + kv_seqlen: jnp.ndarray, + seed: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): """ Wrapper for TE fused attention fwd, where query, key, value are seperated tensors Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 @@ -728,14 +1032,27 @@ def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, scaling_factor=scaling_factor, dropout_probability=dropout_probability, - is_training=is_training) - - -def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, - softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray, - doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, - attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, - scaling_factor: float, dropout_probability: float, is_training: bool): + is_training=is_training, + ) + + +def fused_attn_bwd( + q: jnp.ndarray, + k: jnp.ndarray, + v: jnp.ndarray, + bias: jnp.ndarray, + softmax_aux: jnp.ndarray, + rng_state: jnp.ndarray, + output: jnp.ndarray, + doutput: jnp.ndarray, + q_seqlen: jnp.ndarray, + kv_seqlen: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, + dropout_probability: float, + is_training: bool, +): """ Wrapper for TE fused attention bwd Return the gradients of fused attention with seperated query, key, value tensors @@ -759,4 +1076,5 @@ def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, scaling_factor=scaling_factor, dropout_probability=dropout_probability, - is_training=is_training) + is_training=is_training, + ) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index cc59a6fdc165068c1e92d31ee24721f6742a78da..88fab695d6e2048694786aec555107e42d372288 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -72,6 +72,7 @@ class BasePrimitive(metaclass=ABCMeta): """ return NotImplemented + def register_primitive(cls): """ register jax primitive @@ -85,7 +86,7 @@ def register_primitive(cls): inner_p.multiple_results = cls.multiple_results inner_p.def_impl(partial(xla.apply_primitive, inner_p)) inner_p.def_abstract_eval(cls.abstract) - mlir.register_lowering(inner_p, cls.lowering, platform='cuda') + mlir.register_lowering(inner_p, cls.lowering, platform="cuda") cls.inner_primitive = inner_p outer_p = core.Primitive(name_of_wrapper_p()) @@ -95,8 +96,10 @@ def register_primitive(cls): outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) - outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands, - partition=cls.partition) - mlir.register_lowering(outer_p, - 
mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)) + outer_p_lower.def_partition( + infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition + ) + mlir.register_lowering( + outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results) + ) cls.outer_primitive = outer_p diff --git a/transformer_engine/jax/cpp_extensions/custom_call.py b/transformer_engine/jax/cpp_extensions/custom_call.py index 5bf52ddad57f02e39315eeae17e2dd8d0f36af94..36396a977c62ce43296b36176a681297874e3ea8 100644 --- a/transformer_engine/jax/cpp_extensions/custom_call.py +++ b/transformer_engine/jax/cpp_extensions/custom_call.py @@ -27,19 +27,23 @@ class CustomCallArgsWrapper: wrapper of XLA custom call args """ - def __init__(self, - output_types, - operands, - operand_shapes, - operand_specific_layouts=None, - output_specific_layouts=None): + def __init__( + self, + output_types, + operands, + operand_shapes, + operand_specific_layouts=None, + output_specific_layouts=None, + ): self.output_types = output_types self.operands = operands - self.operand_layouts = CustomCallArgsWrapper.generate_layouts(operand_shapes, - operand_specific_layouts) + self.operand_layouts = CustomCallArgsWrapper.generate_layouts( + operand_shapes, operand_specific_layouts + ) output_shapes = [x.shape for x in output_types] - self.output_layouts = CustomCallArgsWrapper.generate_layouts(output_shapes, - output_specific_layouts) + self.output_layouts = CustomCallArgsWrapper.generate_layouts( + output_shapes, output_specific_layouts + ) @staticmethod def generate_layouts(shapes, specific_layouts): @@ -67,19 +71,21 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs): XLA custom call warpper """ if hasattr(mlir, "custom_call"): - out = mlir.custom_call(name, - result_types=args.output_types, - operands=args.operands, - operand_layouts=args.operand_layouts, - result_layouts=args.output_layouts, - backend_config=opaque, - has_side_effect=has_side_effect, - **kwargs).results + out = mlir.custom_call( + name, + result_types=args.output_types, + operands=args.operands, + operand_layouts=args.operand_layouts, + result_layouts=args.output_layouts, + backend_config=opaque, + has_side_effect=has_side_effect, + **kwargs + ).results else: # Need to disable one pylint error as the second function # parameter name recenctly in JAX. Otherwise we won't be # compatible with multiple JAX version. 
- out = custom_call( # pylint: disable=too-many-function-args + out = custom_call( # pylint: disable=too-many-function-args name, args.output_types, operands=args.operands, @@ -87,5 +93,6 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs): result_layouts=args.output_layouts, backend_config=opaque, has_side_effect=has_side_effect, - **kwargs) + **kwargs + ) return out diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index e5b760faa32906dff91e7944ef5bce926bbe8a07..b27e97d7b57ebffe8918005d32dd8eb674168ea4 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -27,7 +27,7 @@ def te_dtype_to_jax_dtype(te_dtype): TEDType.kInt64: jnp.int64, TEDType.kFloat8E4M3: jnp.float8_e4m3fn, TEDType.kFloat8E5M2: jnp.float8_e5m2, - TEDType.kByte: jnp.uint8 + TEDType.kByte: jnp.uint8, } if te_dtype not in converter: @@ -88,12 +88,11 @@ def check_valid_batch_dims(bdims): Assert out non-supported bath dims """ for dim in bdims: - assert dim in [0, None], \ - "Currently only support batch_dim in [0, None], " \ - f"but got {dim=}" + assert dim in [0, None], f"Currently only support batch_dim in [0, None], but got {dim=}" + def normalize_axis_boundary(axis, ndim): - """ NA """ + """NA""" return axis if axis >= 0 else ndim + axis @@ -119,10 +118,13 @@ def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary): Xt = (dim0, dim3, dim4, dim1. dim2) """ if static_axis_boundary < 0: - static_axis_boundary = -1 # means no static axes - assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose. + static_axis_boundary = -1 # means no static axes + assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose. 
transpose_start_idx = static_axis_boundary + 1 transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, len(shape)) assert transpose_start_idx < transpose_axis_boundary - return (*shape[:transpose_start_idx], *shape[transpose_axis_boundary:], - *shape[transpose_start_idx:transpose_axis_boundary]) + return ( + *shape[:transpose_start_idx], + *shape[transpose_axis_boundary:], + *shape[transpose_start_idx:transpose_axis_boundary], + ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 6ee3a90675ab7b6778163ca79180509aa879e7bf..59468db0daaa171965be8552a3bf58022877e078 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -23,27 +23,29 @@ from .misc import ( check_valid_batch_dims, jax_dtype_to_te_dtype, jax_dtype_to_ir_dtype, - te_dtype_to_jax_dtype + te_dtype_to_jax_dtype, ) -from ..sharding import (all_reduce_max_along_all_axes_except_PP, - all_reduce_sum_along_dp_fsdp) +from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp -__all__ = ['layernorm_fwd', - 'layernorm_bwd', - 'rmsnorm_fwd', - 'rmsnorm_bwd', - 'layernorm_fwd_fp8', - 'rmsnorm_fwd_fp8', - ] +__all__ = [ + "layernorm_fwd", + "layernorm_bwd", + "rmsnorm_fwd", + "rmsnorm_bwd", + "layernorm_fwd_fp8", + "rmsnorm_fwd_fp8", +] + class LayerNormFwdPrimitive(BasePrimitive): """ Layer Normalization Forward Primitive """ + name = "te_layernorm_forward" multiple_results = True - impl_static_args = (3, 4) # zero_centered_gamma, epsilon + impl_static_args = (3, 4) # zero_centered_gamma, epsilon inner_primitive = None outer_primitive = None @@ -65,18 +67,21 @@ class LayerNormFwdPrimitive(BasePrimitive): assert x_aval.size % hidden_size == 0 wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( - x_aval.size // hidden_size, # batch size + x_aval.size // hidden_size, # batch size hidden_size, - jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype - jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16) + jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16) True, - kwargs['zero_centered_gamma'], - kwargs['epsilon']) - wkspace_aval = out_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) - barrier_aval = out_aval.update(shape=barrier_info[0], - dtype=te_dtype_to_jax_dtype(barrier_info[1])) + kwargs["zero_centered_gamma"], + kwargs["epsilon"], + ) + wkspace_aval = out_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + barrier_aval = out_aval.update( + shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) + ) return out_aval, mu_aval, rsigma_aval, wkspace_aval, barrier_aval @@ -85,8 +90,7 @@ class LayerNormFwdPrimitive(BasePrimitive): """ LayerNorm fwd outer primitive abstract """ - out_aval, mu_aval, rsigma_aval, _, _ = \ - LayerNormFwdPrimitive.abstract(*args, **kwargs) + out_aval, mu_aval, rsigma_aval, _, _ = LayerNormFwdPrimitive.abstract(*args, **kwargs) return out_aval, mu_aval, rsigma_aval @staticmethod @@ -124,7 +128,7 @@ class LayerNormFwdPrimitive(BasePrimitive): ir.RankedTensorType.get(batch_shape, ir_mu_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), 
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), - ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)) + ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), ] operands = [x, gamma, beta] operand_shapes = [x_shape, g_shape, b_shape] @@ -137,14 +141,14 @@ class LayerNormFwdPrimitive(BasePrimitive): hidden_size, wkspace_aval.size, barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass + (0,), # no dgamma_part in FWD pass + (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype + TEDType.kByte, # dummy dgamma_part te_dtype + TEDType.kByte, # dummy dbeta_part te_dtype zero_centered_gamma, epsilon, sm_margin, @@ -161,7 +165,8 @@ class LayerNormFwdPrimitive(BasePrimitive): """ assert LayerNormFwdPrimitive.inner_primitive is not None out, mu, rsigma, _, _ = LayerNormFwdPrimitive.inner_primitive.bind( - x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon) + x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + ) return out, mu, rsigma @staticmethod @@ -175,11 +180,12 @@ class LayerNormFwdPrimitive(BasePrimitive): x_bdim, _, _ = batch_dims out_bdims = x_bdim, x_bdim, x_bdim - return LayerNormFwdPrimitive.outer_primitive.bind(x, - gamma, - beta, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon), out_bdims + return ( + LayerNormFwdPrimitive.outer_primitive.bind( + x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + ), + out_bdims, + ) @staticmethod def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): @@ -187,9 +193,9 @@ class LayerNormFwdPrimitive(BasePrimitive): x_spec = get_padded_spec(arg_infos[0]) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \ - f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ - f"and hurt performance." + f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " + "Force to not shard the hidden dim, which might introduce extra collective ops, " + "and hurt performance." ) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) @@ -201,19 +207,19 @@ class LayerNormFwdPrimitive(BasePrimitive): x_spec, g_spec, b_spec = map(get_padded_spec, arg_infos) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \ - f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ - f"and hurt performance." + f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " + "Force to not shard the hidden dim, which might introduce extra collective ops, " + "and hurt performance." ) if g_spec[-1] is not None: warnings.warn( - f"{LayerNormFwdPrimitive.name} does not support sharding of parameter gamma " \ - f"Enforcing no sharding of parameters hidden dim! " \ + f"{LayerNormFwdPrimitive.name} does not support sharding of parameter gamma " + "Enforcing no sharding of parameters hidden dim! 
" ) if b_spec[-1] is not None: warnings.warn( - f"{LayerNormFwdPrimitive.name} does not support sharding of parameter beta " \ - f"Enforcing no sharding of parameters hidden dim! " \ + f"{LayerNormFwdPrimitive.name} does not support sharding of parameter beta " + "Enforcing no sharding of parameters hidden dim! " ) x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) @@ -224,34 +230,34 @@ class LayerNormFwdPrimitive(BasePrimitive): arg_shardings = (x_sharding, g_sharding, b_sharding) out_shardings = (out_sharding, mu_sharding, rsigma_sharding) - impl = partial(LayerNormFwdPrimitive.impl, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + impl = partial( + LayerNormFwdPrimitive.impl, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + ) return mesh, impl, out_shardings, arg_shardings register_primitive(LayerNormFwdPrimitive) -def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, - epsilon: float): +def layernorm_fwd( + x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float +): """ Wrapper for TE layernorm fwd """ - return LayerNormFwdPrimitive.outer_primitive.bind(x, - gamma, - beta, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + return LayerNormFwdPrimitive.outer_primitive.bind( + x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + ) class LayerNormBwdPrimitive(BasePrimitive): """ Layer Normalization Backward Primitive """ + name = "te_layernorm_backward" multiple_results = True - impl_static_args = (5, 6) # zero_centered_gamma, epsilon + impl_static_args = (5, 6) # zero_centered_gamma, epsilon inner_primitive = None outer_primitive = None @@ -272,33 +278,48 @@ class LayerNormBwdPrimitive(BasePrimitive): dx_aval = core.raise_to_shaped(dz_aval) dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval) - wkspace_info, barrier_info, dgamma_part_info, dbeta_part_info = \ + wkspace_info, barrier_info, dgamma_part_info, dbeta_part_info = ( transformer_engine_jax.get_layernorm_bwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - True, kwargs['zero_centered_gamma'], kwargs['epsilon'] + True, + kwargs["zero_centered_gamma"], + kwargs["epsilon"], ) - wkspace_aval = dx_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) - barrier_aval = dx_aval.update(shape=barrier_info[0], - dtype=te_dtype_to_jax_dtype(barrier_info[1])) - dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0], - dtype=te_dtype_to_jax_dtype(dgamma_part_info[1])) - dbeta_part_aval = dbeta_aval.update(shape=dbeta_part_info[0], - dtype=te_dtype_to_jax_dtype(dbeta_part_info[1])) - - return dx_aval, dgamma_aval, dbeta_aval, wkspace_aval, barrier_aval, \ - dgamma_part_aval, dbeta_part_aval + ) + wkspace_aval = dx_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + barrier_aval = dx_aval.update( + shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) + ) + dgamma_part_aval = dgamma_aval.update( + shape=dgamma_part_info[0], dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]) + ) + dbeta_part_aval = dbeta_aval.update( + shape=dbeta_part_info[0], dtype=te_dtype_to_jax_dtype(dbeta_part_info[1]) + ) + 
+ return ( + dx_aval, + dgamma_aval, + dbeta_aval, + wkspace_aval, + barrier_aval, + dgamma_part_aval, + dbeta_part_aval, + ) @staticmethod def outer_abstract(*args, **kwargs): """ LayerNorm bwd outer primitive abstract """ - dx_aval, dgamma_aval, dbeta_aval, _, _, _, _ = \ - LayerNormBwdPrimitive.abstract(*args, **kwargs) + dx_aval, dgamma_aval, dbeta_aval, _, _, _, _ = LayerNormBwdPrimitive.abstract( + *args, **kwargs + ) return dx_aval, dgamma_aval, dbeta_aval @staticmethod @@ -361,7 +382,8 @@ class LayerNormBwdPrimitive(BasePrimitive): def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon): assert LayerNormBwdPrimitive.inner_primitive is not None dx, dgamma, dbeta, _, _, _, _ = LayerNormBwdPrimitive.inner_primitive.bind( - dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon) + dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + ) return dx, dgamma, dbeta @staticmethod @@ -372,13 +394,12 @@ class LayerNormBwdPrimitive(BasePrimitive): _, x_bdim, _, _, gamma_bdim = batch_dims out_bdims = x_bdim, gamma_bdim, gamma_bdim - return LayerNormBwdPrimitive.outer_primitive.bind(dz, - x, - mu, - rsigma, - gamma, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon), out_bdims + return ( + LayerNormBwdPrimitive.outer_primitive.bind( + dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + ), + out_bdims, + ) @staticmethod def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): @@ -386,16 +407,16 @@ class LayerNormBwdPrimitive(BasePrimitive): x_spec = get_padded_spec(arg_infos[1]) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " \ - f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ - f"and hurt performance." + f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " + "Force to not shard the hidden dim, which might introduce extra collective ops, " + "and hurt performance." ) g_b_spec = get_padded_spec(arg_infos[4]) if g_b_spec[-1] is not None: warnings.warn( - f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \ - f"of gamma and beta of Layernorm " \ - f"Enforcing no sharding of parameters hidden dim! " \ + f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " + "of gamma and beta of Layernorm " + "Enforcing no sharding of parameters hidden dim! " ) dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) @@ -408,30 +429,29 @@ class LayerNormBwdPrimitive(BasePrimitive): x_spec = get_padded_spec(arg_infos[1]) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " \ - f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ - f"and hurt performance." + f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " + "Force to not shard the hidden dim, which might introduce extra collective ops, " + "and hurt performance." ) g_b_spec = get_padded_spec(arg_infos[4]) if g_b_spec[-1] is not None: warnings.warn( - f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \ - f"of gamma and beta of Layernorm " \ - f"Enforcing no sharding of parameters hidden dim! " \ + f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " + "of gamma and beta of Layernorm " + "Enforcing no sharding of parameters hidden dim! 
" ) dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None)) out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding - x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding. + x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding. mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2 arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(None))) def sharded_impl(dz, x, mu, rsigma, gamma): - local_dx, local_dgamma, local_dbeta = \ - LayerNormBwdPrimitive.impl(dz, x, mu, rsigma, gamma, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl( + dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + ) global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma) global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta) return local_dx, global_dgamma, global_dbeta @@ -442,27 +462,31 @@ class LayerNormBwdPrimitive(BasePrimitive): register_primitive(LayerNormBwdPrimitive) -def layernorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray, - gamma: jnp.ndarray, zero_centered_gamma: bool, epsilon: float): +def layernorm_bwd( + dz: jnp.ndarray, + x: jnp.ndarray, + mu: jnp.ndarray, + rsigma: jnp.ndarray, + gamma: jnp.ndarray, + zero_centered_gamma: bool, + epsilon: float, +): """ Wrapper for TE layernorm bwd """ - return LayerNormBwdPrimitive.outer_primitive.bind(dz, - x, - mu, - rsigma, - gamma, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + return LayerNormBwdPrimitive.outer_primitive.bind( + dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + ) class RmsNormFwdPrimitive(BasePrimitive): """ RMS Normalization Forward Primitive """ + name = "te_rmsnorm_forward" multiple_results = True - impl_static_args = (2,) # epsilon + impl_static_args = (2,) # epsilon inner_primitive = None outer_primitive = None @@ -483,18 +507,21 @@ class RmsNormFwdPrimitive(BasePrimitive): assert x_aval.size % hidden_size == 0 wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( - x_aval.size // hidden_size, # batch size + x_aval.size // hidden_size, # batch size hidden_size, - jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype - jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16) + jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16) False, False, - kwargs['epsilon']) - wkspace_aval = out_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) - barrier_aval = out_aval.update(shape=barrier_info[0], - dtype=te_dtype_to_jax_dtype(barrier_info[1])) + kwargs["epsilon"], + ) + wkspace_aval = out_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + barrier_aval = out_aval.update( + shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) + ) return out_aval, rsigma_aval, wkspace_aval, barrier_aval @@ -529,7 +556,7 @@ class RmsNormFwdPrimitive(BasePrimitive): ir.RankedTensorType.get(out_shape, x_type.element_type), ir.RankedTensorType.get(batch_shape, rsigma_element_type), ir.RankedTensorType.get(wkspace_aval.shape, 
jax_dtype_to_ir_dtype(wkspace_aval.dtype)), - ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)) + ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), ] operands = [x, gamma] operand_shapes = [x_shape, g_shape] @@ -542,15 +569,15 @@ class RmsNormFwdPrimitive(BasePrimitive): hidden_size, wkspace_aval.size, barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass + (0,), # no dgamma_part in FWD pass + (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype - False, # RMSNorm doesn't support zero_centered_gamma + TEDType.kByte, # dummy dgamma_part te_dtype + TEDType.kByte, # dummy dbeta_part te_dtype + False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, ) @@ -587,9 +614,9 @@ class RmsNormFwdPrimitive(BasePrimitive): x_spec = get_padded_spec(arg_infos[0]) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! " \ - f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ - f"and hurt performance." + f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! " + "Force to not shard the hidden dim, which might introduce extra collective ops, " + "and hurt performance." ) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) @@ -601,14 +628,14 @@ class RmsNormFwdPrimitive(BasePrimitive): x_spec, g_spec = map(get_padded_spec, arg_infos) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! " \ - f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ - f"and hurt performance." + f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! " + "Force to not shard the hidden dim, which might introduce extra collective ops, " + "and hurt performance." ) if g_spec[-1] is not None: warnings.warn( - f"{RmsNormFwdPrimitive.name} does not support sharding of parameter gamma " \ - f"Enforcing no sharding of parameters hidden dim! " \ + f"{RmsNormFwdPrimitive.name} does not support sharding of parameter gamma " + "Enforcing no sharding of parameters hidden dim! 
" ) x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) @@ -635,9 +662,10 @@ class RmsNormBwdPrimitive(BasePrimitive): """ RMS Normalization Backward Primitive """ + name = "te_rmsnorm_backward" multiple_results = True - impl_static_args = (4,) # epsilon + impl_static_args = (4,) # epsilon inner_primitive = None outer_primitive = None @@ -657,20 +685,26 @@ class RmsNormBwdPrimitive(BasePrimitive): dx_aval = core.raise_to_shaped(dz_aval) dgamma_aval = core.raise_to_shaped(gamma_aval) - wkspace_info, barrier_info, dgamma_part_info, _ = \ + wkspace_info, barrier_info, dgamma_part_info, _ = ( transformer_engine_jax.get_layernorm_bwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - False, False, kwargs['epsilon'] + False, + False, + kwargs["epsilon"], ) - wkspace_aval = dx_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) - barrier_aval = dx_aval.update(shape=barrier_info[0], - dtype=te_dtype_to_jax_dtype(barrier_info[1])) - dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0], - dtype=te_dtype_to_jax_dtype(dgamma_part_info[1])) + ) + wkspace_aval = dx_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + barrier_aval = dx_aval.update( + shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) + ) + dgamma_part_aval = dgamma_aval.update( + shape=dgamma_part_info[0], dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]) + ) return dx_aval, dgamma_aval, wkspace_aval, barrier_aval, dgamma_part_aval @@ -705,8 +739,9 @@ class RmsNormBwdPrimitive(BasePrimitive): ir.RankedTensorType.get(g_shape, g_type.element_type), ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), - ir.RankedTensorType.get(dgamma_part_aval.shape, - jax_dtype_to_ir_dtype(dgamma_part_aval.dtype)) + ir.RankedTensorType.get( + dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype) + ), ] operands = [dz, rsigma, x, gamma] operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape] @@ -720,14 +755,14 @@ class RmsNormBwdPrimitive(BasePrimitive): wkspace_aval.size, barrier_aval.size, dgamma_part_aval.shape, - (0,), # no dbeta_part for RMSnorm + (0,), # no dbeta_part for RMSnorm jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(barrier_aval.dtype), jax_dtype_to_te_dtype(dgamma_part_aval.dtype), - TEDType.kByte, # dummy dbeta_part te_dtype - False, # RMSNorm doesn't support zero_centered_gamma + TEDType.kByte, # dummy dbeta_part te_dtype + False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, ) @@ -739,8 +774,9 @@ class RmsNormBwdPrimitive(BasePrimitive): @staticmethod def impl(dz, x, rsigma, gamma, epsilon): assert RmsNormBwdPrimitive.inner_primitive is not None - dx, dgamma, _, _, _ = \ - RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon) + dx, dgamma, _, _, _ = RmsNormBwdPrimitive.inner_primitive.bind( + dz, x, rsigma, gamma, epsilon=epsilon + ) return dx, dgamma @staticmethod @@ -751,8 +787,10 @@ class RmsNormBwdPrimitive(BasePrimitive): _, x_bdim, _, gamma_bdim = 
batch_dims out_bdims = x_bdim, gamma_bdim - return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, - epsilon=epsilon), out_bdims + return ( + RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon), + out_bdims, + ) @staticmethod def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos): @@ -760,15 +798,15 @@ class RmsNormBwdPrimitive(BasePrimitive): x_spec = get_padded_spec(arg_infos[1]) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! " \ - f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ - f"and hurt performance." + f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! " + "Force to not shard the hidden dim, which might introduce extra collective ops, " + "and hurt performance." ) g_spec = get_padded_spec(arg_infos[3]) if g_spec[-1] is not None: warnings.warn( - f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " \ - f"Enforcing no sharding of parameters hidden dim! " \ + f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " + "Enforcing no sharding of parameters hidden dim! " ) dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) dgamma_sharding = NamedSharding(mesh, PartitionSpec(None)) @@ -780,26 +818,25 @@ class RmsNormBwdPrimitive(BasePrimitive): x_spec = get_padded_spec(arg_infos[1]) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! " \ - f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ - f"and hurt performance." + f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! " + "Force to not shard the hidden dim, which might introduce extra collective ops, " + "and hurt performance." ) g_spec = get_padded_spec(arg_infos[3]) if g_spec[-1] is not None: warnings.warn( - f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " \ - f"Enforcing no sharding of parameters hidden dim! " \ + f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " + "Enforcing no sharding of parameters hidden dim! " ) dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) dgamma_sharding = NamedSharding(mesh, PartitionSpec(None)) out_shardings = dx_sharding, dgamma_sharding - x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding. + x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding. 
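The partition rules of the backward primitives (LayerNormBwdPrimitive above, RmsNormBwdPrimitive here) keep dx sharded like the input but sum the locally computed weight gradients over the data-parallel/FSDP mesh axes via all_reduce_sum_along_dp_fsdp. Below is a minimal stand-alone sketch of that pattern using plain JAX shard_map and psum rather than TE's helper; the mesh axis name "dp" and the toy gradient math are assumptions made only for illustration.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Toy one-axis mesh; this also runs on a single device.
mesh = Mesh(np.array(jax.devices()[:1]), ("dp",))

def local_bwd(dz, x, gamma):
    # Stand-in for the per-shard backward: each "dp" shard sees a slice of the
    # batch rows and produces a full-width weight gradient for its slice only.
    dgamma_local = jnp.sum(dz * x, axis=0)   # reduce over the local batch rows
    dx_local = dz * gamma                    # stays sharded like dz / x
    # Sum the weight gradient across the data-parallel axis, mirroring
    # all_reduce_sum_along_dp_fsdp in the sharded_impl functions of this file.
    dgamma_global = jax.lax.psum(dgamma_local, axis_name="dp")
    return dx_local, dgamma_global

sharded_bwd = shard_map(
    local_bwd,
    mesh,
    in_specs=(P("dp", None), P("dp", None), P(None)),
    out_specs=(P("dp", None), P(None)),
)

dz = jnp.ones((8, 128))
x = jnp.ones((8, 128))
gamma = jnp.ones((128,))
dx, dgamma = sharded_bwd(dz, x, gamma)   # dgamma is identical on every shard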
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(None))) def sharded_impl(dz, x, rsigma, gamma): - local_dx, local_dgamma = \ - RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon) + local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon) global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma) return local_dx, global_dgamma @@ -809,8 +846,9 @@ class RmsNormBwdPrimitive(BasePrimitive): register_primitive(RmsNormBwdPrimitive) -def rmsnorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray, - epsilon: float): +def rmsnorm_bwd( + dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray, epsilon: float +): """ Wrapper for TE layernorm bwd """ @@ -821,15 +859,26 @@ class LayerNormFwdFp8Primitive(BasePrimitive): """ Layer Normalization Forward FP8 Primitive """ + name = "te_layernorm_forward_fp8" multiple_results = True - impl_static_args = (6, 7, 8) # out_type, zero_centered_gamma, epsilon + impl_static_args = (6, 7, 8) # out_type, zero_centered_gamma, epsilon inner_primitive = None outer_primitive = None @staticmethod - def abstract(x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, - zero_centered_gamma, epsilon): + def abstract( + x_aval, + gamma_aval, + beta_aval, + amax_aval, + scale_aval, + scale_inv_aval, + *, + out_dtype, + zero_centered_gamma, + epsilon, + ): """ LayerNorm fwd (fp8 out) inner primitive abstract """ @@ -845,22 +894,25 @@ class LayerNormFwdFp8Primitive(BasePrimitive): assert gamma_aval.size == beta_aval.size wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # in type - jax_dtype_to_te_dtype(gamma_aval.dtype), # weight type + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # in type + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight type jax_dtype_to_te_dtype(out_dtype), True, zero_centered_gamma, - epsilon) + epsilon, + ) out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - wkspace_aval = x_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) - barrier_aval = x_aval.update(shape=barrier_info[0], - dtype=te_dtype_to_jax_dtype(barrier_info[1])) + wkspace_aval = x_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + barrier_aval = x_aval.update( + shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) + ) return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval, barrier_aval @@ -869,13 +921,15 @@ class LayerNormFwdFp8Primitive(BasePrimitive): """ LayerNorm fwd (fp8 out) outer primitive abstract """ - out_aval, mu_aval, rsigma_aval, updated_amax_aval, _, _ = \ - LayerNormFwdFp8Primitive.abstract(*args, **kwargs) + out_aval, mu_aval, rsigma_aval, updated_amax_aval, _, _ = LayerNormFwdFp8Primitive.abstract( + *args, **kwargs + ) return out_aval, mu_aval, rsigma_aval, updated_amax_aval @staticmethod - def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_centered_gamma, - epsilon): + def lowering( + ctx, x, gamma, beta, amax, scale, scale_inv, 
*, out_dtype, zero_centered_gamma, epsilon + ): """ LayerNorm fwd (fp8 out) lowering rules """ @@ -922,11 +976,16 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), - ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)) + ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), ] operands = [x, gamma, beta, amax, scale, scale_inv] operand_shapes = [ - x_shape, g_shape, b_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape + x_shape, + g_shape, + b_shape, + ir_amax_shape, + ir_scale_shape, + ir_scale_inv_shape, ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) @@ -937,24 +996,22 @@ class LayerNormFwdFp8Primitive(BasePrimitive): hidden_size, wkspace_aval.size, barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass + (0,), # no dgamma_part in FWD pass + (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype + TEDType.kByte, # dummy dgamma_part te_dtype + TEDType.kByte, # dummy dbeta_part te_dtype zero_centered_gamma, epsilon, sm_margin, ) - out = custom_caller(LayerNormFwdFp8Primitive.name, - args, - opaque, - False, - operand_output_aliases={3: 3}) + out = custom_caller( + LayerNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={3: 3} + ) return out @@ -973,7 +1030,8 @@ class LayerNormFwdFp8Primitive(BasePrimitive): scale_inv, out_dtype=out_dtype, zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + epsilon=epsilon, + ) return out, mu, rsigma, updated_amax @staticmethod @@ -987,27 +1045,33 @@ class LayerNormFwdFp8Primitive(BasePrimitive): x_bdim, _, _, amax_bdim, _, _ = batch_dims out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim - return LayerNormFwdFp8Primitive.outer_primitive.bind( - x, - gamma, - beta, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon), out_bdims + return ( + LayerNormFwdFp8Primitive.outer_primitive.bind( + x, + gamma, + beta, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon, + ), + out_bdims, + ) @staticmethod - def infer_sharding_from_operands(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, - result_infos): + def infer_sharding_from_operands( + out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos + ): del out_dtype, zero_centered_gamma, epsilon, result_infos x_spec = get_padded_spec(arg_infos[0]) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \ - f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ - f"and hurt performance.") + f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " + "Force to not shard the hidden dim, which might introduce extra collective ops, " + "and hurt performance." 
+ ) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) @@ -1022,37 +1086,44 @@ class LayerNormFwdFp8Primitive(BasePrimitive): b_spec = get_padded_spec(arg_infos[2]) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " \ - f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ - f"and hurt performance." + f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " + "Force to not shard the hidden dim, which might introduce extra collective ops, " + "and hurt performance." ) if g_spec[-1] is not None: warnings.warn( - f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter gamma " \ - f"Enforcing no sharding of parameters hidden dim! " \ + f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter gamma " + "Enforcing no sharding of parameters hidden dim! " ) if b_spec[-1] is not None: warnings.warn( - f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter beta " \ - f"Enforcing no sharding of parameters hidden dim! " \ + f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter beta " + "Enforcing no sharding of parameters hidden dim! " ) x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) g_sharding = NamedSharding(mesh, PartitionSpec(None)) b_sharding = NamedSharding(mesh, PartitionSpec(None)) out_sharding = x_sharding mu_sharding = rsigma_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1])) + mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]) + ) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3]))) fp8_meta_sharding = amax_sharding arg_shardings = (x_sharding, g_sharding, b_sharding) + (fp8_meta_sharding,) * 3 out_shardings = (out_sharding, mu_sharding, rsigma_sharding, amax_sharding) def sharded_impl(x, gamma, beta, amax, scale, scale_inv): - local_x, local_mu, local_rsigma, local_amax = \ - LayerNormFwdFp8Primitive.impl(x, gamma, beta, amax, scale, scale_inv, - out_dtype=out_dtype, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + local_x, local_mu, local_rsigma, local_amax = LayerNormFwdFp8Primitive.impl( + x, + gamma, + beta, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon, + ) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) return local_x, local_mu, local_rsigma, global_updated_amax @@ -1063,30 +1134,41 @@ class LayerNormFwdFp8Primitive(BasePrimitive): register_primitive(LayerNormFwdFp8Primitive) -def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, amax: jnp.ndarray, - scale: jnp.ndarray, scale_inv: jnp.ndarray, out_dtype: jnp.dtype, - zero_centered_gamma: bool, epsilon: float): +def layernorm_fwd_fp8( + x: jnp.ndarray, + gamma: jnp.ndarray, + beta: jnp.ndarray, + amax: jnp.ndarray, + scale: jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: jnp.dtype, + zero_centered_gamma: bool, + epsilon: float, +): """ Wrapper for TE layernorm fwd (fp8 out) """ - return LayerNormFwdFp8Primitive.outer_primitive.bind(x, - gamma, - beta, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + return LayerNormFwdFp8Primitive.outer_primitive.bind( + x, + gamma, + beta, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + 
zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon, + ) class RmsNormFwdFp8Primitive(BasePrimitive): """ RMS Normalization Forward FP8 Primitive """ + name = "te_rmsnorm_forward_fp8" multiple_results = True - impl_static_args = (5, 6) # out_dtype, epsilon + impl_static_args = (5, 6) # out_dtype, epsilon inner_primitive = None outer_primitive = None @@ -1108,22 +1190,25 @@ class RmsNormFwdFp8Primitive(BasePrimitive): rsigama_dtype = jnp.float32 wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( - x_aval.size // hidden_size, # batch_size + x_aval.size // hidden_size, # batch_size hidden_size, - jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype - jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - jax_dtype_to_te_dtype(out_dtype), # out te_dtype + jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + jax_dtype_to_te_dtype(out_dtype), # out te_dtype False, False, - epsilon) + epsilon, + ) out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype) amax_aval = out_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - wkspace_aval = x_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) - barrier_aval = x_aval.update(shape=barrier_info[0], - dtype=te_dtype_to_jax_dtype(barrier_info[1])) + wkspace_aval = x_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + barrier_aval = x_aval.update( + shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) + ) return out_aval, rsigma_aval, amax_aval, wkspace_aval, barrier_aval @@ -1176,7 +1261,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), - ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)) + ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), ] operands = [x, gamma, amax, scale, scale_inv] operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] @@ -1189,24 +1274,22 @@ class RmsNormFwdFp8Primitive(BasePrimitive): hidden_size, wkspace_aval.size, barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass + (0,), # no dgamma_part in FWD pass + (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype - False, # RMSNorm doesn't support zero_centered_gamma + TEDType.kByte, # dummy dgamma_part te_dtype + TEDType.kByte, # dummy dbeta_part te_dtype + False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, ) - out = custom_caller(RmsNormFwdFp8Primitive.name, - args, - opaque, - False, - operand_output_aliases={2: 2}) + out = custom_caller( + RmsNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={2: 2} + ) return out @@ -1216,13 +1299,9 @@ class RmsNormFwdFp8Primitive(BasePrimitive): to describe implementation """ assert RmsNormFwdFp8Primitive.inner_primitive is not None - out, rsigma, amax, _, _ = RmsNormFwdFp8Primitive.inner_primitive.bind(x, - gamma, - amax, - scale, - scale_inv, - 
out_dtype=out_dtype, - epsilon=epsilon) + out, rsigma, amax, _, _ = RmsNormFwdFp8Primitive.inner_primitive.bind( + x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon + ) return out, rsigma, amax @staticmethod @@ -1235,13 +1314,12 @@ class RmsNormFwdFp8Primitive(BasePrimitive): x, gamma, amax, scale, scale_inv = batched_args x_bdim, _, amax_bdim, _, _ = batch_dims out_bdims = x_bdim, x_bdim, amax_bdim - return RmsNormFwdFp8Primitive.outer_primitive.bind(x, - gamma, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - epsilon=epsilon), out_bdims + return ( + RmsNormFwdFp8Primitive.outer_primitive.bind( + x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon + ), + out_bdims, + ) @staticmethod def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_infos): @@ -1249,9 +1327,9 @@ class RmsNormFwdFp8Primitive(BasePrimitive): x_spec = get_padded_spec(arg_infos[0]) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \ - f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ - f"and hurt performance." + f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " + "Force to not shard the hidden dim, which might introduce extra collective ops, " + "and hurt performance." ) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) @@ -1265,14 +1343,14 @@ class RmsNormFwdFp8Primitive(BasePrimitive): g_spec = get_padded_spec(arg_infos[1]) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \ - f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ - f"and hurt performance." + f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " + "Force to not shard the hidden dim, which might introduce extra collective ops, " + "and hurt performance." ) if g_spec[-1] is not None: warnings.warn( - f"{RmsNormFwdFp8Primitive.name} does not support sharding of parameter gamma " \ - f"Enforcing no sharding of parameters hidden dim! " \ + f"{RmsNormFwdFp8Primitive.name} does not support sharding of parameter gamma " + "Enforcing no sharding of parameters hidden dim! 
" ) x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) g_sharding = NamedSharding(mesh, PartitionSpec(None)) @@ -1284,9 +1362,9 @@ class RmsNormFwdFp8Primitive(BasePrimitive): out_shardings = (out_sharding, rsigma_sharding, amax_sharding) def sharded_impl(x, gamma, amax, scale, scale_inv): - local_x, local_rsigma, local_amax= \ - RmsNormFwdFp8Primitive.impl(x, gamma, amax, scale, scale_inv, - out_dtype=out_dtype, epsilon=epsilon) + local_x, local_rsigma, local_amax = RmsNormFwdFp8Primitive.impl( + x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon + ) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) return local_x, local_rsigma, global_updated_amax @@ -1297,15 +1375,18 @@ class RmsNormFwdFp8Primitive(BasePrimitive): register_primitive(RmsNormFwdFp8Primitive) -def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, - scale_inv: jnp.ndarray, out_dtype: jnp.dtype, epsilon: float): +def rmsnorm_fwd_fp8( + x: jnp.ndarray, + gamma: jnp.ndarray, + amax: jnp.ndarray, + scale: jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: jnp.dtype, + epsilon: float, +): """ Wrapper for TE rmsnorm fwd (fp8 out) """ - return RmsNormFwdFp8Primitive.outer_primitive.bind(x, - gamma, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - epsilon=epsilon) + return RmsNormFwdFp8Primitive.outer_primitive.bind( + x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon + ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index e1345fe53bc6b9b373b366adb5091afe94498d48..40974b07b94f8e13bcd9f5f5c1a85ecb86ae1061 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -18,18 +18,19 @@ from .misc import ( get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, - jax_dtype_to_ir_dtype + jax_dtype_to_ir_dtype, ) from ..sharding import all_reduce_max_along_all_axes_except_PP -__all__ = ['cast_fp8'] +__all__ = ["cast_fp8"] class CastFP8Primitive(BasePrimitive): """ Cast Primitive """ + name = "te_quantize" multiple_results = True impl_static_args = (4,) @@ -79,15 +80,13 @@ class CastFP8Primitive(BasePrimitive): operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_common_descriptor(ir_x_shape, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(out_dtype)) + opaque = transformer_engine_jax.pack_common_descriptor( + ir_x_shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype) + ) - out = custom_caller(CastFP8Primitive.name, - args, - opaque, - False, - operand_output_aliases={1: 1}) + out = custom_caller( + CastFP8Primitive.name, args, opaque, False, operand_output_aliases={1: 1} + ) return out @@ -97,9 +96,9 @@ class CastFP8Primitive(BasePrimitive): te_cast implementation """ assert CastFP8Primitive.inner_primitive is not None - casted_x, updated_amax = \ - CastFP8Primitive.inner_primitive.bind( - x, amax, scale, scale_inv, out_dtype=out_dtype) + casted_x, updated_amax = CastFP8Primitive.inner_primitive.bind( + x, amax, scale, scale_inv, out_dtype=out_dtype + ) return casted_x, updated_amax @staticmethod @@ -111,8 +110,10 @@ class CastFP8Primitive(BasePrimitive): x_bdim, amax_bdim, *_ = batch_dims out_bdims = x_bdim, amax_bdim - return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, - 
out_dtype=out_dtype), out_bdims + return ( + CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype), + out_bdims, + ) @staticmethod def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos): @@ -132,8 +133,9 @@ class CastFP8Primitive(BasePrimitive): out_shardings = (casted_x_sharding, amax_sharding) def sharded_impl(x, amax, scale, scale_inv): - local_cx, local_updated_amax = \ - CastFP8Primitive.impl(x, amax, scale, scale_inv, out_dtype=out_dtype) + local_cx, local_updated_amax = CastFP8Primitive.impl( + x, amax, scale, scale_inv, out_dtype=out_dtype + ) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax) return local_cx, global_updated_amax @@ -144,8 +146,13 @@ class CastFP8Primitive(BasePrimitive): register_primitive(CastFP8Primitive) -def cast_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, - out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray]: +def cast_fp8( + x: jnp.ndarray, + amax: jnp.ndarray, + scale: jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: TEDType, +) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Cast wrapper Return FP8 tensor diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index 48e17fb50322ba63692da72103406a590335dc7d..6cc1218bfb93a7c894f8700281c12fae651f67de 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -16,36 +16,42 @@ from transformer_engine import transformer_engine_jax from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper -from .misc import ( - get_padded_spec, - check_valid_batch_dims, - jax_dtype_to_te_dtype -) +from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype from ..softmax import SoftmaxType -__all__ = ['scaled_softmax_fwd', - 'scaled_softmax_bwd', - 'scaled_masked_softmax_fwd', - 'scaled_masked_softmax_bwd', - 'scaled_upper_triang_masked_softmax_fwd', - 'scaled_upper_triang_masked_softmax_bwd', - 'is_softmax_kernel_available', - ] - - -def is_softmax_kernel_available(softmax_type: SoftmaxType, batch: int, heads: int, q_seqlen: int, - k_seqlen: int, dtype: jnp.dtype): +__all__ = [ + "scaled_softmax_fwd", + "scaled_softmax_bwd", + "scaled_masked_softmax_fwd", + "scaled_masked_softmax_bwd", + "scaled_upper_triang_masked_softmax_fwd", + "scaled_upper_triang_masked_softmax_bwd", + "is_softmax_kernel_available", +] + + +def is_softmax_kernel_available( + softmax_type: SoftmaxType, + batch: int, + heads: int, + q_seqlen: int, + k_seqlen: int, + dtype: jnp.dtype, +): """check softmax available""" if softmax_type is SoftmaxType.SCALED: - return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen, - dtype) + return ScaledSoftmaxFwdPrimitive.is_kernel_available( + batch, heads, q_seqlen, k_seqlen, dtype + ) if softmax_type is SoftmaxType.SCALED_MASKED: - return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen, - dtype) + return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available( + batch, heads, q_seqlen, k_seqlen, dtype + ) if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available( - batch, heads, q_seqlen, k_seqlen, dtype) + batch, heads, q_seqlen, k_seqlen, dtype + ) raise NotImplementedError @@ -54,13 +60,15 @@ class SoftmaxPrimitive(BasePrimitive): """ Softmax Primitive """ + 
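is_softmax_kernel_available above centralizes the size and dtype constraints under which the fused softmax kernels apply (fp16/bf16 inputs, 16 <= k_seqlen <= max_k_seqlen_supported, q_seqlen and batch * heads divisible by 4, plus the per-block divisibility check). A caller typically probes it and falls back to plain XLA softmax otherwise. A hedged usage sketch follows, assuming Transformer Engine is installed and the modules are importable under the file paths shown in this diff:

import math

import jax
import jax.numpy as jnp

from transformer_engine.jax.softmax import SoftmaxType
from transformer_engine.jax.cpp_extensions.softmax import (
    is_softmax_kernel_available,
    scaled_softmax_fwd,
)

def softmax_with_fallback(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    # Kernels assume a [..., Batch, Head, Q_Seqlen, K_Seqlen] layout, matching
    # the lowering rules in this file.
    heads, q_seqlen, k_seqlen = logits.shape[-3:]
    batch = math.prod(logits.shape[:-3])
    if is_softmax_kernel_available(
        SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, logits.dtype
    ):
        return scaled_softmax_fwd(logits, scale_factor)
    # Fallback: let XLA compile the unfused implementation.
    return jax.nn.softmax(scale_factor * logits, axis=-1)

probs = softmax_with_fallback(jnp.zeros((2, 16, 128, 128), dtype=jnp.bfloat16), 0.125)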
max_k_seqlen_supported = 16384 name = "te_softmax_internal_placeholder" @staticmethod @abstractmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: + def is_kernel_available( + batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype + ) -> bool: """Check Softmax kernel availability based on size""" raise NotImplementedError @@ -68,7 +76,7 @@ class SoftmaxPrimitive(BasePrimitive): def get_batch_per_block(k_seqlen: int) -> int: """Get batch per CTA in Softmax kernels""" threads_per_warp = 32 - threads_per_block = 128 # Depends on the kernel implmentation + threads_per_block = 128 # Depends on the kernel implmentation pow2 = 1 << (k_seqlen - 1).bit_length() warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp @@ -100,7 +108,7 @@ class SoftmaxPrimitive(BasePrimitive): """ softmax_forward lowering rules """ - i_aval, = ctx.avals_in + (i_aval,) = ctx.avals_in i_type = ir.RankedTensorType(logits.type) i_shape = i_type.shape # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] @@ -115,10 +123,15 @@ class SoftmaxPrimitive(BasePrimitive): operand_shapes = [i_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_softmax_descriptor(batch, pad_batch, heads, q_seqlen, - k_seqlen, - jax_dtype_to_te_dtype(i_aval.dtype), - scale_factor) + opaque = transformer_engine_jax.pack_softmax_descriptor( + batch, + pad_batch, + heads, + q_seqlen, + k_seqlen, + jax_dtype_to_te_dtype(i_aval.dtype), + scale_factor, + ) out = custom_caller(name, args, opaque, False) @@ -139,8 +152,8 @@ class SoftmaxPrimitive(BasePrimitive): softmax_forward batcher """ assert primitive is not None - logits, = batched_args - logits_bdim, = batch_dims + (logits,) = batched_args + (logits_bdim,) = batch_dims out_bdims = logits_bdim return primitive.bind(logits, scale_factor=scale_factor), out_bdims @@ -150,13 +163,13 @@ class SoftmaxPrimitive(BasePrimitive): """ softmax_forward infer_sharding_from_operands """ - del scale_factor, result_infos # Unused. + del scale_factor, result_infos # Unused. logits_spec = get_padded_spec(arg_infos[0]) if logits_spec[-1] is not None: warnings.warn( - f"Sharding the hidden dimension is not supported in {cls.name}! " \ - f"Forcing XLA to not shard the hidden dim, which might introduce extra " \ - f"collective ops and hurt performance." + f"Sharding the hidden dimension is not supported in {cls.name}! " + "Forcing XLA to not shard the hidden dim, which might introduce extra " + "collective ops and hurt performance." ) out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) return out_sharding @@ -170,9 +183,9 @@ class SoftmaxPrimitive(BasePrimitive): logits_spec = get_padded_spec(arg_infos[0]) if logits_spec[-1] is not None: warnings.warn( - f"Sharding the hidden dimension is not supported in {cls.name}! " \ - f"Forcing XLA to not shard the hidden dim, which might introduce extra " \ - f"collective ops and hurt performance." + f"Sharding the hidden dimension is not supported in {cls.name}! " + "Forcing XLA to not shard the hidden dim, which might introduce extra " + "collective ops and hurt performance." 
) out_shardings = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) arg_shardings = (out_shardings,) @@ -180,7 +193,9 @@ class SoftmaxPrimitive(BasePrimitive): return mesh, impl, out_shardings, arg_shardings @staticmethod - def backward_abstract(dz_aval, softmax_out_aval, scale_factor=None): # pylint: disable=unused-argument + def backward_abstract( + dz_aval, softmax_out_aval, scale_factor=None + ): # pylint: disable=unused-argument """ softmax_backward abstract """ @@ -207,7 +222,7 @@ class SoftmaxPrimitive(BasePrimitive): # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] batch = reduce(operator.mul, dz_shape[:-3]) - pad_batch = batch # unused + pad_batch = batch # unused heads = dz_shape[-3] q_seqlen = dz_shape[-2] k_seqlen = dz_shape[-1] @@ -221,8 +236,14 @@ class SoftmaxPrimitive(BasePrimitive): args = CustomCallArgsWrapper(out_types, operands, operand_shapes) opaque = transformer_engine_jax.pack_softmax_descriptor( - batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(dz_aval.dtype), - scale_factor) + batch, + pad_batch, + heads, + q_seqlen, + k_seqlen, + jax_dtype_to_te_dtype(dz_aval.dtype), + scale_factor, + ) out = custom_caller(name, args, opaque, False) @@ -254,13 +275,13 @@ class SoftmaxPrimitive(BasePrimitive): """ softmax_backward infer_sharding_from_operands """ - del scale_factor, result_infos # Unused. + del scale_factor, result_infos # Unused. dz_spec = get_padded_spec(arg_infos[0]) if dz_spec[-1] is not None: warnings.warn( - f"Sharding the hidden dimension is not supported in {cls.name}! " \ - f"Forcing XLA to not shard the hidden dim, which might introduce extra " \ - f"collective ops and hurt performance." + f"Sharding the hidden dimension is not supported in {cls.name}! " + "Forcing XLA to not shard the hidden dim, which might introduce extra " + "collective ops and hurt performance." ) dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None)) return dx_sharding @@ -276,9 +297,9 @@ class SoftmaxPrimitive(BasePrimitive): softmax_out_spec = get_padded_spec(arg_infos[1]) if dz_spec[-1] is not None or softmax_out_spec[-1] is not None: warnings.warn( - f"Sharding the hidden dimension is not supported in {cls.name}! " \ - f"Forcing XLA to not shard the hidden dim, which might introduce extra " \ - f"collective ops and hurt performance." + f"Sharding the hidden dimension is not supported in {cls.name}! " + "Forcing XLA to not shard the hidden dim, which might introduce extra " + "collective ops and hurt performance." 
) dz_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None)) @@ -295,31 +316,34 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive): """ Scaled Softmax Fwd Primitive """ + name = "te_scaled_softmax_forward" multiple_results = False - impl_static_args = (1,) # scale_factor + impl_static_args = (1,) # scale_factor inner_primitive = None outer_primitive = None @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: + def is_kernel_available( + batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype + ) -> bool: """Check Softmax kernel availability based on size""" attn_batches = batch * heads dtype = dtypes.canonicalize_dtype(dtype) - if (dtype in [jnp.float16, jnp.bfloat16] - and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported - and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 - and attn_batches % 4 == 0 # batch * heads must be divisor of 4 - ): + if ( + dtype in [jnp.float16, jnp.bfloat16] + and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported + and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 + and attn_batches % 4 == 0 # batch * heads must be divisor of 4 + ): if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported: batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen) return q_seqlen % batch_per_block == 0 return False @staticmethod - def abstract(logits_aval, scale_factor): # pylint: disable=unused-argument + def abstract(logits_aval, scale_factor): # pylint: disable=unused-argument """ te_scaled_softmax_forward abstract """ @@ -330,34 +354,37 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive): """ te_scaled_softmax_forward lowering rules """ - return SoftmaxPrimitive.forward_lowering(ScaledSoftmaxFwdPrimitive.name, - ctx, - logits, - scale_factor=scale_factor) + return SoftmaxPrimitive.forward_lowering( + ScaledSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor + ) @staticmethod def impl(logits, scale_factor): - return SoftmaxPrimitive.forward_impl(ScaledSoftmaxFwdPrimitive.inner_primitive, logits, - scale_factor) + return SoftmaxPrimitive.forward_impl( + ScaledSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor + ) @staticmethod def batcher(batched_args, batch_dims, *, scale_factor): check_valid_batch_dims(batch_dims) - return SoftmaxPrimitive.forward_batcher(ScaledSoftmaxFwdPrimitive.outer_primitive, - batched_args, - batch_dims, - scale_factor=scale_factor) + return SoftmaxPrimitive.forward_batcher( + ScaledSoftmaxFwdPrimitive.outer_primitive, + batched_args, + batch_dims, + scale_factor=scale_factor, + ) @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos) + scale_factor, mesh, arg_infos, result_infos + ) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): - return ScaledSoftmaxFwdPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl, - scale_factor, mesh, arg_infos, - result_infos) + return ScaledSoftmaxFwdPrimitive.forward_partition( + ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos + ) register_primitive(ScaledSoftmaxFwdPrimitive) @@ -375,18 +402,21 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive): """ Scaled Softmax Bwd Primitive """ + name = "te_scaled_softmax_backward" multiple_results = False - impl_static_args = (2,) # scale_factor + impl_static_args = (2,) # scale_factor inner_primitive = None 
outer_primitive = None @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: + def is_kernel_available( + batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype + ) -> bool: """Check Softmax kernel availability based on size""" - return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen, - dtype) + return ScaledSoftmaxFwdPrimitive.is_kernel_available( + batch, heads, q_seqlen, k_seqlen, dtype + ) @staticmethod def abstract(dz_aval, softmax_out_aval, scale_factor): @@ -400,84 +430,88 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive): """ te_scaled_softmax_backward lowering rules """ - out = SoftmaxPrimitive.backward_lowering(ScaledSoftmaxBwdPrimitive.name, - ctx, - dz, - softmax_out, - scale_factor=scale_factor) + out = SoftmaxPrimitive.backward_lowering( + ScaledSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor + ) return out @staticmethod def impl(dz, softmax_out, scale_factor): - return SoftmaxPrimitive.backward_impl(ScaledSoftmaxBwdPrimitive.inner_primitive, - dz, - softmax_out, - scale_factor=scale_factor) + return SoftmaxPrimitive.backward_impl( + ScaledSoftmaxBwdPrimitive.inner_primitive, dz, softmax_out, scale_factor=scale_factor + ) @staticmethod def batcher(batched_args, batch_dims, *, scale_factor): check_valid_batch_dims(batch_dims) - return SoftmaxPrimitive.backward_batcher(ScaledSoftmaxBwdPrimitive.outer_primitive, - batched_args, - batch_dims, - scale_factor=scale_factor) + return SoftmaxPrimitive.backward_batcher( + ScaledSoftmaxBwdPrimitive.outer_primitive, + batched_args, + batch_dims, + scale_factor=scale_factor, + ) @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos) + scale_factor, mesh, arg_infos, result_infos + ) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): - return ScaledSoftmaxBwdPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl, - scale_factor, mesh, arg_infos, - result_infos) + return ScaledSoftmaxBwdPrimitive.backward_partition( + ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos + ) register_primitive(ScaledSoftmaxBwdPrimitive) -def scaled_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray, - scale_factor: float) -> jnp.ndarray: +def scaled_softmax_bwd( + dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float +) -> jnp.ndarray: """ scaled_backward wrapper Return FP16/BF16 tensor """ - return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(dz, - softmax_out, - scale_factor=scale_factor) + return ScaledSoftmaxBwdPrimitive.outer_primitive.bind( + dz, softmax_out, scale_factor=scale_factor + ) class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): """ Scaled Masked Softmax Fwd Primitive """ + name = "te_scaled_masked_softmax_forward" multiple_results = False - impl_static_args = (2,) # scale_factor + impl_static_args = (2,) # scale_factor inner_primitive = None outer_primitive = None @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: + def is_kernel_available( + batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype + ) -> bool: """Check Softmax kernel availability based on size""" attn_batches = batch * heads dtype = dtypes.canonicalize_dtype(dtype) - if (dtype in [jnp.float16, jnp.bfloat16] - and 16 <= 
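The forward and backward primitives in this file come in pairs that higher-level code stitches together with jax.custom_vjp, so that differentiating the scaled softmax routes through the fused backward kernel. A minimal sketch of that wiring is shown below with jnp stand-ins in place of the fused calls so it runs without the TE extension; the residuals and the exact glue used inside TE may differ.

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_softmax(logits, scale_factor):
    # Stand-in for scaled_softmax_fwd(logits, scale_factor).
    return jax.nn.softmax(scale_factor * logits, axis=-1)

def _scaled_softmax_fwd_rule(logits, scale_factor):
    out = jax.nn.softmax(scale_factor * logits, axis=-1)
    return out, out                      # keep softmax_out as the residual

def _scaled_softmax_bwd_rule(scale_factor, softmax_out, dz):
    # Analytic gradient of softmax(scale_factor * logits); the fused backward
    # wrapper receives exactly these inputs (dz, softmax_out, scale_factor).
    dlogits = scale_factor * softmax_out * (
        dz - jnp.sum(dz * softmax_out, axis=-1, keepdims=True)
    )
    return (dlogits,)

scaled_softmax.defvjp(_scaled_softmax_fwd_rule, _scaled_softmax_bwd_rule)

x = jnp.ones((2, 4, 8, 8))
grads = jax.grad(lambda t: jnp.sum(scaled_softmax(t, 0.125) ** 2))(x)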
k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported - and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 - and attn_batches % 4 == 0 # batch * heads must be divisor of 4 - ): + if ( + dtype in [jnp.float16, jnp.bfloat16] + and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported + and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 + and attn_batches % 4 == 0 # batch * heads must be divisor of 4 + ): if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported: batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen) return q_seqlen % batch_per_block == 0 return False @staticmethod - def abstract(logits_aval, mask_aval, scale_factor): # pylint: disable=unused-argument + def abstract(logits_aval, mask_aval, scale_factor): # pylint: disable=unused-argument """ te_scaled_masked_softmax_forward abstract """ @@ -499,8 +533,8 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ] mask_shape = mask_aval.shape pad_batch = batch = reduce(operator.mul, mask_shape[:-3]) - assert pad_batch in (1, batch) # 1 means broadcast - assert mask_shape[-3] == 1 # 1 means broadcast + assert pad_batch in (1, batch) # 1 means broadcast + assert mask_shape[-3] == 1 # 1 means broadcast assert mask_shape[-2] == q_seqlen assert mask_shape[-1] == k_seqlen @@ -532,8 +566,14 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): args = CustomCallArgsWrapper(out_types, operands, operand_shapes) opaque = transformer_engine_jax.pack_softmax_descriptor( - batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(logits_aval.dtype), - scale_factor) + batch, + pad_batch, + heads, + q_seqlen, + k_seqlen, + jax_dtype_to_te_dtype(logits_aval.dtype), + scale_factor, + ) out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False) @@ -542,9 +582,9 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): @staticmethod def impl(logits, mask, scale_factor): assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None - output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(logits, - mask, - scale_factor=scale_factor) + output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind( + logits, mask, scale_factor=scale_factor + ) return output @staticmethod @@ -555,50 +595,60 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): logits_bdim, _ = batch_dims out_bdims = logits_bdim - return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind( - logits, mask, scale_factor=scale_factor), out_bdims + return ( + ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind( + logits, mask, scale_factor=scale_factor + ), + out_bdims, + ) @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos) + scale_factor, mesh, arg_infos, result_infos + ) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxFwdPrimitive.backward_partition( - ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos) + ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos + ) register_primitive(ScaledMaskedSoftmaxFwdPrimitive) -def scaled_masked_softmax_fwd(logits: jnp.ndarray, mask: jnp.ndarray, - scale_factor: float) -> jnp.ndarray: +def scaled_masked_softmax_fwd( + logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float +) -> jnp.ndarray: """ scaled_masked_softmax_forward wrapper Return FP16/BF16 tensor """ - return 
ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(logits, - mask, - scale_factor=scale_factor) + return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind( + logits, mask, scale_factor=scale_factor + ) class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): """ Scaled Masked Softmax Bwd Primitive """ + name = "te_scaled_masked_softmax_backward" multiple_results = False - impl_static_args = (2,) # scale_factor + impl_static_args = (2,) # scale_factor inner_primitive = None outer_primitive = None @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: + def is_kernel_available( + batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype + ) -> bool: """Check Softmax kernel availability based on size""" - return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen, - dtype) + return ScaledSoftmaxFwdPrimitive.is_kernel_available( + batch, heads, q_seqlen, k_seqlen, dtype + ) @staticmethod def abstract(dz_aval, softmax_out_aval, *, scale_factor): @@ -612,83 +662,92 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): """ te_scaled_upper_triang_masked_backward lowering rules """ - out = SoftmaxPrimitive.backward_lowering(ScaledMaskedSoftmaxBwdPrimitive.name, - ctx, - dz, - softmax_out, - scale_factor=scale_factor) + out = SoftmaxPrimitive.backward_lowering( + ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor + ) return out @staticmethod def impl(dz, softmax_out, scale_factor): - return SoftmaxPrimitive.backward_impl(ScaledMaskedSoftmaxBwdPrimitive.inner_primitive, - dz, - softmax_out, - scale_factor=scale_factor) + return SoftmaxPrimitive.backward_impl( + ScaledMaskedSoftmaxBwdPrimitive.inner_primitive, + dz, + softmax_out, + scale_factor=scale_factor, + ) @staticmethod def batcher(batched_args, batch_dims, *, scale_factor): check_valid_batch_dims(batch_dims) - return SoftmaxPrimitive.backward_batcher(ScaledMaskedSoftmaxBwdPrimitive.outer_primitive, - batched_args, - batch_dims, - scale_factor=scale_factor) + return SoftmaxPrimitive.backward_batcher( + ScaledMaskedSoftmaxBwdPrimitive.outer_primitive, + batched_args, + batch_dims, + scale_factor=scale_factor, + ) @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos) + scale_factor, mesh, arg_infos, result_infos + ) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxBwdPrimitive.backward_partition( - ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos) + ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos + ) register_primitive(ScaledMaskedSoftmaxBwdPrimitive) -def scaled_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray, - scale_factor: float) -> jnp.ndarray: +def scaled_masked_softmax_bwd( + dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float +) -> jnp.ndarray: """ scaled_masked_backward wrapper Return FP16/BF16 tensor """ - return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(dz, - softmax_out, - scale_factor=scale_factor) + return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind( + dz, softmax_out, scale_factor=scale_factor + ) class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): """ Scaled Upper Triang Masked Softmax Fwd Primitive """ + name = 
"te_scaled_upper_triang_masked_softmax_forward" multiple_results = False - impl_static_args = (1,) # scale_factor + impl_static_args = (1,) # scale_factor inner_primitive = None outer_primitive = None @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: + def is_kernel_available( + batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype + ) -> bool: """Check Softmax kernel availability based on size""" attn_batches = batch * heads dtype = dtypes.canonicalize_dtype(dtype) - if (dtype in [jnp.float16, jnp.bfloat16] - and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported - and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 - and attn_batches % 4 == 0 # batch * heads must be divisor of 4 - and k_seqlen == q_seqlen): + if ( + dtype in [jnp.float16, jnp.bfloat16] + and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported + and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 + and attn_batches % 4 == 0 # batch * heads must be divisor of 4 + and k_seqlen == q_seqlen + ): if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported: batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen) return attn_batches % batch_per_block == 0 return False @staticmethod - def abstract(logits_aval, scale_factor): # pylint: disable=unused-argument + def abstract(logits_aval, scale_factor): # pylint: disable=unused-argument """ te_scaled_upper_triang_masked_softmax_forward abstract """ @@ -702,15 +761,15 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): """ te_scaled_upper_triang_masked_softmax_forward lowering rules """ - return SoftmaxPrimitive.forward_lowering(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name, - ctx, - logits, - scale_factor=scale_factor) + return SoftmaxPrimitive.forward_lowering( + ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor + ) @staticmethod def impl(logits, scale_factor): return SoftmaxPrimitive.forward_impl( - ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor) + ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor + ) @staticmethod def batcher(batched_args, batch_dims, *, scale_factor): @@ -719,18 +778,24 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive, batched_args, batch_dims, - scale_factor=scale_factor) + scale_factor=scale_factor, + ) @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos) + scale_factor, mesh, arg_infos, result_infos + ) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition( - ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, - result_infos) + ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, + scale_factor, + mesh, + arg_infos, + result_infos, + ) register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive) @@ -742,25 +807,29 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl Return FP16/BF16 tensor """ return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind( - logits, scale_factor=scale_factor) + logits, scale_factor=scale_factor + ) class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): """ Scaled Upper Triang 
Masked Softmax Bwd Primitive """ + name = "te_scaled_upper_triang_masked_softmax_backward" multiple_results = False - impl_static_args = (2,) # scale_factor + impl_static_args = (2,) # scale_factor inner_primitive = None outer_primitive = None @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: + def is_kernel_available( + batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype + ) -> bool: """Check Softmax kernel availability based on size""" return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available( - batch, heads, q_seqlen, k_seqlen, dtype) + batch, heads, q_seqlen, k_seqlen, dtype + ) @staticmethod def abstract(dz_aval, softmax_out_aval, *, scale_factor): @@ -774,11 +843,13 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): """ te_scaled_upper_triang_masked_backward lowering rules """ - out = SoftmaxPrimitive.backward_lowering(ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name, - ctx, - dz, - softmax_out, - scale_factor=scale_factor) + out = SoftmaxPrimitive.backward_lowering( + ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name, + ctx, + dz, + softmax_out, + scale_factor=scale_factor, + ) return out @@ -788,7 +859,8 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive, dz, softmax_out, - scale_factor=scale_factor) + scale_factor=scale_factor, + ) @staticmethod def batcher(batched_args, batch_dims, *, scale_factor): @@ -797,28 +869,36 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive, batched_args, batch_dims, - scale_factor=scale_factor) + scale_factor=scale_factor, + ) @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos) + scale_factor, mesh, arg_infos, result_infos + ) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition( - ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, - result_infos) + ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, + scale_factor, + mesh, + arg_infos, + result_infos, + ) register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) -def scaled_upper_triang_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray, - scale_factor: float) -> jnp.ndarray: +def scaled_upper_triang_masked_softmax_bwd( + dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float +) -> jnp.ndarray: """ scaled_upper_triang_masked_backward wrapper Return FP16/BF16 tensor """ return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind( - dz, softmax_out, scale_factor=scale_factor) + dz, softmax_out, scale_factor=scale_factor + ) diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index 818b910034f66007a874c22b479d8bbaa7238053..696342e3d79e2113676a795c58866894bc02f5f2 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -23,27 +23,26 @@ from .misc import ( te_dtype_to_jax_dtype, get_padded_spec, multidim_transpose, - normalize_axis_boundary + normalize_axis_boundary, ) from .activation import ActivationEnum -from ..sharding import ( - all_reduce_max_along_all_axes_except_PP, - 
all_reduce_sum_along_dp_fsdp -) +from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp -__all__ = ['transpose', - 'cast_transpose', - 'dbias_cast_transpose', - 'dact_lu_dbias_cast_transpose', - 'dgated_act_lu_cast_transpose', - ] +__all__ = [ + "transpose", + "cast_transpose", + "dbias_cast_transpose", + "dact_lu_dbias_cast_transpose", + "dgated_act_lu_cast_transpose", +] class TransposePrimitive(BasePrimitive): """ Transpose Primitive """ + name = "te_transpose" multiple_results = False impl_static_args = (1, 2) @@ -55,8 +54,9 @@ class TransposePrimitive(BasePrimitive): """ _transpose abstract """ - transposed_x_shape = multidim_transpose(x_aval.shape, static_axis_boundary, - transpose_axis_boundary) + transposed_x_shape = multidim_transpose( + x_aval.shape, static_axis_boundary, transpose_axis_boundary + ) xt_aval = x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype) return xt_aval @@ -69,7 +69,11 @@ class TransposePrimitive(BasePrimitive): x_aval = ctx.avals_in[0] assert x_aval.dtype in [ - jnp.float32, jnp.float16, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e5m2 + jnp.float32, + jnp.float16, + jnp.bfloat16, + jnp.float8_e4m3fn, + jnp.float8_e5m2, ] ir_x_type = ir.RankedTensorType(x.type) @@ -79,8 +83,9 @@ class TransposePrimitive(BasePrimitive): for i in range(static_axis_boundary + 1): assert ir_x_shape[i] == 1 - transposed_x_shape = multidim_transpose(ir_x_shape, static_axis_boundary, - transpose_axis_boundary) + transposed_x_shape = multidim_transpose( + ir_x_shape, static_axis_boundary, transpose_axis_boundary + ) out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)] operands = [x] @@ -88,10 +93,13 @@ class TransposePrimitive(BasePrimitive): args = CustomCallArgsWrapper(out_types, operands, operand_shapes) te_dtype = jax_dtype_to_te_dtype(x_aval.dtype) - contracted_x_shape = (reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), - reduce(operator.mul, ir_x_shape[transpose_axis_boundary:])) - opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, te_dtype, - te_dtype) + contracted_x_shape = ( + reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), + reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]), + ) + opaque = transformer_engine_jax.pack_common_descriptor( + contracted_x_shape, te_dtype, te_dtype + ) out = custom_caller(TransposePrimitive.name, args, opaque, False) @@ -103,10 +111,11 @@ class TransposePrimitive(BasePrimitive): tcast_transpose implementation """ assert TransposePrimitive.inner_primitive is not None - transposed_x = \ - TransposePrimitive.inner_primitive.bind(x, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) + transposed_x = TransposePrimitive.inner_primitive.bind( + x, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) return transposed_x @staticmethod @@ -115,21 +124,25 @@ class TransposePrimitive(BasePrimitive): assert TransposePrimitive.outer_primitive is not None assert static_axis_boundary < 0 - x, = batched_args - x_bdim, = batch_dims + (x,) = batched_args + (x_bdim,) = batch_dims # Minus batch dim. 
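Note (not part of the patch): the lowering above contracts the input shape at transpose_axis_boundary into a 2-D matrix before calling the custom kernel. The same semantics can be stated in a few lines of plain JAX; this is a reference sketch only (static batch axes ignored), not the custom-call implementation.

# Reference sketch of the contracted transpose computed by the te_transpose
# custom call: flatten to (prod(leading), prod(trailing)), swap the two groups.
import operator
from functools import reduce
import jax.numpy as jnp

def reference_transpose(x: jnp.ndarray, transpose_axis_boundary: int) -> jnp.ndarray:
    boundary = transpose_axis_boundary % x.ndim
    rows = reduce(operator.mul, x.shape[:boundary], 1)
    cols = reduce(operator.mul, x.shape[boundary:], 1)
    contracted = x.reshape(rows, cols)
    # The output places the original trailing dims first, matching multidim_transpose.
    return contracted.T.reshape(*x.shape[boundary:], *x.shape[:boundary])

# Example: a (2, 3, 4) tensor transposed at boundary 1 becomes (3, 4, 2).
x = jnp.arange(24).reshape(2, 3, 4)
assert reference_transpose(x, 1).shape == (3, 4, 2)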
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) - transpose_axis_boundary += 1 # Plus batch dim + transpose_axis_boundary += 1 # Plus batch dim out_bdims = x_bdim - return TransposePrimitive.outer_primitive.bind( - x, static_axis_boundary=x_bdim, - transpose_axis_boundary=transpose_axis_boundary), out_bdims + return ( + TransposePrimitive.outer_primitive.bind( + x, static_axis_boundary=x_bdim, transpose_axis_boundary=transpose_axis_boundary + ), + out_bdims, + ) @staticmethod - def infer_sharding_from_operands(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, - result_infos): + def infer_sharding_from_operands( + static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos + ): del result_infos x_spec = get_padded_spec(arg_infos[0]) xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) @@ -145,9 +158,11 @@ class TransposePrimitive(BasePrimitive): arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = transposed_x_sharding - impl = partial(TransposePrimitive.impl, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) + impl = partial( + TransposePrimitive.impl, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) return mesh, impl, out_shardings, arg_shardings @@ -155,20 +170,24 @@ class TransposePrimitive(BasePrimitive): register_primitive(TransposePrimitive) -def transpose(x: jnp.ndarray, static_axis_boundary: int, - transpose_axis_boundary: int) -> jnp.ndarray: +def transpose( + x: jnp.ndarray, static_axis_boundary: int, transpose_axis_boundary: int +) -> jnp.ndarray: """ transpose wrapper """ - return TransposePrimitive.outer_primitive.bind(x, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) + return TransposePrimitive.outer_primitive.bind( + x, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) class CastTransposePrimitive(BasePrimitive): """ Cast Transpose Primitive """ + name = "te_cast_transpose" multiple_results = True impl_static_args = (4, 5, 6) @@ -176,8 +195,16 @@ class CastTransposePrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, static_axis_boundary, - transpose_axis_boundary): + def abstract( + x_aval, + amax_aval, + scale_aval, + scale_inv_aval, + *, + out_dtype, + static_axis_boundary, + transpose_axis_boundary + ): """ te_cast_transpose_p abstract """ @@ -187,8 +214,9 @@ class CastTransposePrimitive(BasePrimitive): assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - transposed_x_shape = multidim_transpose(x_aval.shape, static_axis_boundary, - transpose_axis_boundary) + transposed_x_shape = multidim_transpose( + x_aval.shape, static_axis_boundary, transpose_axis_boundary + ) casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype) @@ -197,8 +225,9 @@ class CastTransposePrimitive(BasePrimitive): return casted_x_aval, casted_xt_aval, updated_amax_aval @staticmethod - def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, - transpose_axis_boundary): + def lowering( + ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary + ): """ te_cast_transpose_p lowering rules """ @@ -219,8 +248,9 @@ class 
CastTransposePrimitive(BasePrimitive): ir_scale_shape = ir_amax_shape ir_scale_inv_shape = ir_amax_shape - transposed_x_shape = multidim_transpose(ir_x_shape, static_axis_boundary, - transpose_axis_boundary) + transposed_x_shape = multidim_transpose( + ir_x_shape, static_axis_boundary, transpose_axis_boundary + ) out_types = [ ir.RankedTensorType.get(ir_x_shape, ir_out_dtype), @@ -231,17 +261,19 @@ class CastTransposePrimitive(BasePrimitive): operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - contracted_x_shape = (reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), - reduce(operator.mul, ir_x_shape[transpose_axis_boundary:])) - opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(out_dtype)) + contracted_x_shape = ( + reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), + reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]), + ) + opaque = transformer_engine_jax.pack_common_descriptor( + contracted_x_shape, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + ) - out = custom_caller(CastTransposePrimitive.name, - args, - opaque, - False, - operand_output_aliases={1: 2}) + out = custom_caller( + CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2} + ) return out @@ -251,16 +283,21 @@ class CastTransposePrimitive(BasePrimitive): te_cast_transpose implementation """ assert CastTransposePrimitive.inner_primitive is not None - casted_x, casted_transposed_x, updated_amax = \ - CastTransposePrimitive.inner_primitive.bind( - x, amax, scale, scale_inv, out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) + casted_x, casted_transposed_x, updated_amax = CastTransposePrimitive.inner_primitive.bind( + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) return casted_x, casted_transposed_x, updated_amax @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, - transpose_axis_boundary): + def batcher( + batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary + ): check_valid_batch_dims(batch_dims) assert CastTransposePrimitive.outer_primitive is not None assert static_axis_boundary < 0 @@ -270,21 +307,26 @@ class CastTransposePrimitive(BasePrimitive): # Minus batch dim. 
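Note (not part of the patch): the fused kernel bound here returns three outputs in one pass over x: the FP8 cast, the FP8 cast of the transpose, and an updated abs-max for delayed scaling (the amax operand is aliased to that output via operand_output_aliases={1: 2}). A plain-JAX reference of the intended semantics, as a sketch under a delayed-scaling FP8 assumption, not the TE kernel:

# Reference sketch of cast_transpose semantics: FP8(x), FP8(x.T), updated amax.
import jax.numpy as jnp

def reference_cast_transpose(x, amax, scale, out_dtype=jnp.float8_e4m3fn):
    casted = (x * scale).astype(out_dtype)                  # FP8(x)
    casted_t = (x.T * scale).astype(out_dtype)              # FP8(x.T)
    updated_amax = jnp.maximum(amax, jnp.max(jnp.abs(x)))   # abs-max feeding the next scale
    return casted, casted_t, updated_amax

x = jnp.ones((128, 64), dtype=jnp.bfloat16)
amax = jnp.zeros((1,), dtype=jnp.float32)
cx, cxt, new_amax = reference_cast_transpose(x, amax, jnp.float32(1.0))
assert cx.shape == (128, 64) and cxt.shape == (64, 128) and cx.dtype == jnp.float8_e4m3fn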
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) - transpose_axis_boundary += 1 # Plus batch dim + transpose_axis_boundary += 1 # Plus batch dim out_bdims = x_bdim, x_bdim, amax_bdim - return CastTransposePrimitive.outer_primitive.bind( - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=x_bdim, - transpose_axis_boundary=transpose_axis_boundary), out_bdims + return ( + CastTransposePrimitive.outer_primitive.bind( + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=x_bdim, + transpose_axis_boundary=transpose_axis_boundary, + ), + out_bdims, + ) @staticmethod - def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, - arg_infos, result_infos): + def infer_sharding_from_operands( + out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos + ): del out_dtype, result_infos x_spec = get_padded_spec(arg_infos[0]) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) @@ -294,8 +336,9 @@ class CastTransposePrimitive(BasePrimitive): return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding) @staticmethod - def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, - result_infos): + def partition( + out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos + ): del result_infos x_spec = get_padded_spec(arg_infos[0]) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) @@ -306,11 +349,15 @@ class CastTransposePrimitive(BasePrimitive): out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding) def sharded_impl(x, amax, scale, scale_inv): - local_cx, local_cxt, local_updated_amax = \ - CastTransposePrimitive.impl(x, amax, scale, scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) + local_cx, local_cxt, local_updated_amax = CastTransposePrimitive.impl( + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax) return local_cx, local_cxt, global_updated_amax @@ -321,9 +368,15 @@ class CastTransposePrimitive(BasePrimitive): register_primitive(CastTransposePrimitive) -def cast_transpose(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, - out_dtype: jnp.dtype, static_axis_boundary: int, - transpose_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: +def cast_transpose( + x: jnp.ndarray, + amax: jnp.ndarray, + scale: jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: jnp.dtype, + static_axis_boundary: int, + transpose_axis_boundary: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ cast transpose wrapper Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale` @@ -335,13 +388,15 @@ def cast_transpose(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_ scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) + transpose_axis_boundary=transpose_axis_boundary, + ) class DBiasCastTransposePrimitive(BasePrimitive): """ DBias Cast Transpose Primitive """ + name = "te_dbias_cast_transpose" multiple_results = True # out_dtype, static_axis_boundary, transpose_axis_boundary @@ -350,8 +405,16 @@ class 
DBiasCastTransposePrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(dz_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, - static_axis_boundary, transpose_axis_boundary): + def abstract( + dz_aval, + amax_aval, + scale_aval, + scale_inv_aval, + *, + out_dtype, + static_axis_boundary, + transpose_axis_boundary + ): """ te_dbias_cast_transpose_p abstract """ @@ -365,18 +428,19 @@ class DBiasCastTransposePrimitive(BasePrimitive): out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype) t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) - dbias_shape = (*dz_aval.shape[:static_axis_boundary + 1], gi_hidden_size) + dbias_shape = (*dz_aval.shape[: static_axis_boundary + 1], gi_hidden_size) dbias = dz_aval.update(shape=dbias_shape, dtype=dtype) updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - wkspace_info, = transformer_engine_jax.get_dbias_ct_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_dbias_ct_workspace_sizes( dz_aval.size // gi_hidden_size, gi_hidden_size, jax_dtype_to_te_dtype(dz_aval.dtype), - jax_dtype_to_te_dtype(out_dtype) + jax_dtype_to_te_dtype(out_dtype), + ) + wkspace_aval = dz_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - wkspace_aval = dz_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) return out, t_out, dbias, updated_amax_aval, wkspace_aval @@ -386,13 +450,15 @@ class DBiasCastTransposePrimitive(BasePrimitive): te_dbias_cast_transpose_p outer abstract """ - out, t_out, dbias, updated_amax_aval, _ = \ - DBiasCastTransposePrimitive.abstract(*args, **kwargs) + out, t_out, dbias, updated_amax_aval, _ = DBiasCastTransposePrimitive.abstract( + *args, **kwargs + ) return out, t_out, dbias, updated_amax_aval @staticmethod - def lowering(ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, - transpose_axis_boundary): + def lowering( + ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary + ): """ te_dbias_cast_transpose_p lowering rules """ @@ -412,9 +478,10 @@ class DBiasCastTransposePrimitive(BasePrimitive): ir_amax_shape = ir_amax_type.shape ir_scale_shape = ir_amax_shape ir_scale_inv_shape = ir_amax_shape - transposed_dz_shape = multidim_transpose(ir_dz_shape, static_axis_boundary, - transpose_axis_boundary) - dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_size) + transposed_dz_shape = multidim_transpose( + ir_dz_shape, static_axis_boundary, transpose_axis_boundary + ) + dbias_shape = (*ir_dz_shape[: static_axis_boundary + 1], ir_hidden_size) wkspace_aval = ctx.avals_out[-1] @@ -429,20 +496,21 @@ class DBiasCastTransposePrimitive(BasePrimitive): operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) opaque = transformer_engine_jax.pack_common_wk_descriptor( - contracted_dz_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype)) + contracted_dz_shape, + wkspace_aval.shape, + jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + ) - out = custom_caller(DBiasCastTransposePrimitive.name, - args, - opaque, - False, - operand_output_aliases={1: 3}) + out = custom_caller( + DBiasCastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 3} + ) return out @staticmethod - def impl(dz, amax, 
scale, scale_inv, out_dtype, static_axis_boundary, - transpose_axis_boundary): + def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary): """ to describe implementation """ @@ -454,12 +522,14 @@ class DBiasCastTransposePrimitive(BasePrimitive): scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) + transpose_axis_boundary=transpose_axis_boundary, + ) return out, t_out, dbias, updated_amax @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, - transpose_axis_boundary): + def batcher( + batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary + ): """ to describe batch rules for vmap """ @@ -471,34 +541,41 @@ class DBiasCastTransposePrimitive(BasePrimitive): # Minus batch dim. transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1) - transpose_axis_boundary += 1 # Plus batch dim + transpose_axis_boundary += 1 # Plus batch dim out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim - return DBiasCastTransposePrimitive.outer_primitive.bind( - dz, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=dz_bdim, - transpose_axis_boundary=transpose_axis_boundary), out_bdims + return ( + DBiasCastTransposePrimitive.outer_primitive.bind( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=dz_bdim, + transpose_axis_boundary=transpose_axis_boundary, + ), + out_bdims, + ) @staticmethod - def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, - arg_infos, result_infos): + def infer_sharding_from_operands( + out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos + ): del out_dtype, result_infos x_spec = get_padded_spec(arg_infos[0]) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) dbias_shaprding = NamedSharding( - mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) + mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) + ) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) @staticmethod - def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, - result_infos): + def partition( + out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos + ): del result_infos x_spec = get_padded_spec(arg_infos[0]) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) @@ -506,12 +583,17 @@ class DBiasCastTransposePrimitive(BasePrimitive): casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) dbias_shaprding = NamedSharding( - mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) + mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) + ) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding, - amax_sharding) + out_shardings = ( + casted_x_sharding, + casted_transposed_x_sharding, + dbias_shaprding, + amax_sharding, + ) def sharded_impl(dz, amax, scale, scale_inv): local_out, local_t_out, local_dbias, 
local_amax = DBiasCastTransposePrimitive.impl( @@ -521,7 +603,8 @@ class DBiasCastTransposePrimitive(BasePrimitive): scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) + transpose_axis_boundary=transpose_axis_boundary, + ) global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) return local_out, local_t_out, global_dbias, global_updated_amax @@ -539,13 +622,14 @@ def dbias_cast_transpose( scale_inv: jnp.ndarray, out_dtype: TEDType, static_axis_boundary: int, - transpose_axis_boundary: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + transpose_axis_boundary: int = -1, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ cast transpose dbias partial fusion wrapper Return FP8(inputs), dbias """ if static_axis_boundary < 0: - static_axis_boundary = -1 # means no static axes + static_axis_boundary = -1 # means no static axes return DBiasCastTransposePrimitive.outer_primitive.bind( dz, @@ -554,13 +638,15 @@ def dbias_cast_transpose( scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) + transpose_axis_boundary=transpose_axis_boundary, + ) class DActLuDBiasCastTransposePrimitive(BasePrimitive): """ DActLu DBias Cast Transpose Primitive """ + name = "te_dact_lu_dbias_cast_transpose" multiple_results = True # out_dtype, static_axis_boundary, transpose_axis_boundary, act_enum @@ -569,9 +655,18 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, - static_axis_boundary, transpose_axis_boundary, - act_enum): # pylint: disable=unused-argument + def abstract( + dz_aval, + x_aval, + amax_aval, + scale_aval, + scale_inv_aval, + *, + out_dtype, + static_axis_boundary, + transpose_axis_boundary, + act_enum + ): # pylint: disable=unused-argument """ te_dact_lu_dbais_cast_transpose_p abstract """ @@ -584,24 +679,24 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ir_hidden_szie = dz_aval.shape[-1] gi_hidden_size = x_aval.shape[-1] assert ir_hidden_szie == gi_hidden_size - t_shape = multidim_transpose(x_aval.shape, - static_axis_boundary, transpose_axis_boundary) + t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary) out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) - dbias_shape = (*x_aval.shape[:static_axis_boundary + 1], gi_hidden_size) + dbias_shape = (*x_aval.shape[: static_axis_boundary + 1], gi_hidden_size) dbias = dz_aval.update(shape=dbias_shape, dtype=dtype) updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - wkspace_info, = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes( x_aval.size // gi_hidden_size, gi_hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), ) - wkspace_aval = x_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + wkspace_aval = x_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) return out, t_out, dbias, updated_amax_aval, wkspace_aval @@ -611,13 +706,25 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): te_dact_lu_dbais_cast_transpose_p outer abstract """ - out, t_out, dbias, 
updated_amax_aval, _ = \ - DActLuDBiasCastTransposePrimitive.abstract(*args, **kwargs) + out, t_out, dbias, updated_amax_aval, _ = DActLuDBiasCastTransposePrimitive.abstract( + *args, **kwargs + ) return out, t_out, dbias, updated_amax_aval @staticmethod - def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, - transpose_axis_boundary, act_enum): + def lowering( + ctx, + dz, + x, + amax, + scale, + scale_inv, + *, + out_dtype, + static_axis_boundary, + transpose_axis_boundary, + act_enum + ): """ te_dgated_act_lu_cast_transpose_p lowering rules """ @@ -643,9 +750,10 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ir_amax_shape = ir_amax_type.shape ir_scale_shape = ir_amax_shape ir_scale_inv_shape = ir_amax_shape - transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, - transpose_axis_boundary) - dbias_shape = (*x_shape[:static_axis_boundary + 1], ir_hidden_szie) + transposed_x_shape = multidim_transpose( + x_shape, static_axis_boundary, transpose_axis_boundary + ) + dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_szie) wkspace_aval = ctx.avals_out[-1] @@ -660,21 +768,36 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) opaque = transformer_engine_jax.pack_common_wk_descriptor( - contracted_x_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - act_enum) + contracted_x_shape, + wkspace_aval.shape, + jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + act_enum, + ) - out = custom_caller(DActLuDBiasCastTransposePrimitive.name, - args, - opaque, - False, - operand_output_aliases={2: 3}) + out = custom_caller( + DActLuDBiasCastTransposePrimitive.name, + args, + opaque, + False, + operand_output_aliases={2: 3}, + ) return out @staticmethod - def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, - transpose_axis_boundary, act_enum): + def impl( + dz, + x, + amax, + scale, + scale_inv, + out_dtype, + static_axis_boundary, + transpose_axis_boundary, + act_enum, + ): """ to describe implementation """ @@ -688,12 +811,20 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, transpose_axis_boundary=transpose_axis_boundary, - act_enum=act_enum) + act_enum=act_enum, + ) return out, t_out, dbias, updated_amax @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, - transpose_axis_boundary, act_enum): + def batcher( + batched_args, + batch_dims, + *, + out_dtype, + static_axis_boundary, + transpose_axis_boundary, + act_enum + ): """ to describe batch rules for vmap """ @@ -705,36 +836,55 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): # Minus batch dim. 
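Note (not part of the patch): unfused, the per-shard work that dact_lu_dbias_cast_transpose performs amounts to the sketch below (GELU assumed, non-gated case; names are illustrative). The partition rule further down then sum-reduces the local dbias across DP/FSDP shards and max-reduces the local amax across everything except pipeline parallelism.

# Illustration only: backprop through the activation, reduce the bias gradient
# over all axes except the hidden one, then cast and transpose the result.
import jax
import jax.numpy as jnp

def reference_dact_dbias_cast_transpose(dz, x, scale, out_dtype=jnp.float8_e4m3fn):
    _, act_vjp = jax.vjp(jax.nn.gelu, x)    # d/dx gelu(x), chained with dz
    (dx,) = act_vjp(dz)
    hidden = dx.shape[-1]
    dbias = dx.reshape(-1, hidden).sum(axis=0)                        # bias gradient (assumed from dact output)
    casted = (dx * scale).astype(out_dtype)                           # FP8(dact)
    casted_t = (dx.reshape(-1, hidden).T * scale).astype(out_dtype)   # FP8(dact.T)
    amax = jnp.max(jnp.abs(dx))
    return casted, casted_t, dbias, amax

dz = jnp.ones((2, 16, 64), dtype=jnp.bfloat16)
out, out_t, dbias, amax = reference_dact_dbias_cast_transpose(dz, jnp.ones_like(dz), jnp.float32(1.0))
assert out_t.shape == (64, 2 * 16) and dbias.shape == (64,)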
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) - transpose_axis_boundary += 1 # Plus batch dim + transpose_axis_boundary += 1 # Plus batch dim out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim - return DActLuDBiasCastTransposePrimitive.outer_primitive.bind( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=x_bdim, - transpose_axis_boundary=transpose_axis_boundary, - act_enum=act_enum), out_bdims + return ( + DActLuDBiasCastTransposePrimitive.outer_primitive.bind( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=x_bdim, + transpose_axis_boundary=transpose_axis_boundary, + act_enum=act_enum, + ), + out_bdims, + ) @staticmethod - def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, - act_enum, mesh, arg_infos, result_infos): + def infer_sharding_from_operands( + out_dtype, + static_axis_boundary, + transpose_axis_boundary, + act_enum, + mesh, + arg_infos, + result_infos, + ): del out_dtype, result_infos, act_enum x_spec = get_padded_spec(arg_infos[1]) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) dbias_shaprding = NamedSharding( - mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) + mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) + ) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) @staticmethod - def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, - act_enum, mesh, arg_infos, result_infos): + def partition( + out_dtype, + static_axis_boundary, + transpose_axis_boundary, + act_enum, + mesh, + arg_infos, + result_infos, + ): del result_infos x_spec = get_padded_spec(arg_infos[1]) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) @@ -742,25 +892,32 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) dbias_shaprding = NamedSharding( - mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) + mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) + ) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding, - amax_sharding) + out_shardings = ( + casted_x_sharding, + casted_transposed_x_sharding, + dbias_shaprding, + amax_sharding, + ) def sharded_impl(dz, x, amax, scale, scale_inv): - local_out, local_t_out, local_dbias, local_amax =\ - DActLuDBiasCastTransposePrimitive.impl( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - act_enum=act_enum) + local_out, local_t_out, local_dbias, local_amax = ( + DActLuDBiasCastTransposePrimitive.impl( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + act_enum=act_enum, + ) + ) global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) return local_out, local_t_out, global_dbias, global_updated_amax @@ 
-780,15 +937,15 @@ def dact_lu_dbias_cast_transpose( out_dtype: TEDType, static_axis_boundary: int, transpose_axis_boundary: int = -1, - activation_type: Sequence[Union[str, Callable]] = ('gelu',) - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + activation_type: Sequence[Union[str, Callable]] = ("gelu",), +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ cast transpose dact_lu and dbias fusion wrapper Return FP8(dact_lu(inputs)), dbias ONLY support non-gated activation type """ if static_axis_boundary < 0: - static_axis_boundary = -1 # means no static axes + static_axis_boundary = -1 # means no static axes act_type_id = ActivationEnum[activation_type] return DActLuDBiasCastTransposePrimitive.outer_primitive.bind( @@ -800,29 +957,40 @@ def dact_lu_dbias_cast_transpose( out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, transpose_axis_boundary=transpose_axis_boundary, - act_enum=act_type_id) + act_enum=act_type_id, + ) class DgatedActLuCastTransposePrimitive(BasePrimitive): """ Dgated ActLu Cast Transpose Primitive """ + name = "te_dgated_act_lu_cast_transpose" multiple_results = True - impl_static_args = (5, 6, 7) # out_dtype, static_axis_boundary, act_enum + impl_static_args = (5, 6, 7) # out_dtype, static_axis_boundary, act_enum inner_primitive = None outer_primitive = None @staticmethod - def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, - static_axis_boundary, act_enum): # pylint: disable=unused-argument + def abstract( + dz_aval, + x_aval, + amax_aval, + scale_aval, + scale_inv_aval, + *, + out_dtype, + static_axis_boundary, + act_enum + ): # pylint: disable=unused-argument """ te_dgated_act_lu_cast_transpose_p abstract """ dtype = dtypes.canonicalize_dtype(dz_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dtype - assert x_aval.shape[-2] == 2 # Linear + GeLU + assert x_aval.shape[-2] == 2 # Linear + GeLU assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 @@ -853,7 +1021,7 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive): dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) x_batch_size = reduce(operator.mul, x_shape[:-2]) assert dz_batch_szie == x_batch_size - assert x_shape[-2] == 2 # Linear + GeLU + assert x_shape[-2] == 2 # Linear + GeLU ir_hidden_szie = ir_dz_shape[-1] gi_hidden_size = x_shape[-1] assert ir_hidden_szie == gi_hidden_size @@ -877,13 +1045,16 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive): contracted_x_shape, jax_dtype_to_te_dtype(dz_aval.dtype), jax_dtype_to_te_dtype(out_dtype), - act_enum) + act_enum, + ) - out = custom_caller(DgatedActLuCastTransposePrimitive.name, - args, - opaque, - False, - operand_output_aliases={2: 2}) + out = custom_caller( + DgatedActLuCastTransposePrimitive.name, + args, + opaque, + False, + operand_output_aliases={2: 2}, + ) return out @@ -901,7 +1072,8 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive): scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - act_enum=act_enum) + act_enum=act_enum, + ) return out, t_out, updated_amax @staticmethod @@ -916,14 +1088,24 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive): x_bdim, _, amax_bdim, _, _ = batch_dims out_bdims = x_bdim, x_bdim, amax_bdim - return DgatedActLuCastTransposePrimitive.outer_primitive.bind( - dz, x, amax, scale, scale_inv, out_dtype=out_dtype, - static_axis_boundary=x_bdim, - act_enum=act_enum), out_bdims + return ( + 
DgatedActLuCastTransposePrimitive.outer_primitive.bind( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=x_bdim, + act_enum=act_enum, + ), + out_bdims, + ) @staticmethod - def infer_sharding_from_operands(out_dtype, static_axis_boundary, act_enum, - mesh, arg_infos, result_infos): + def infer_sharding_from_operands( + out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos + ): del out_dtype, result_infos, act_enum x_spec = get_padded_spec(arg_infos[1]) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) @@ -933,8 +1115,7 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive): return (out_sharding, tranposed_out_sharding, amax_sharding) @staticmethod - def partition(out_dtype, static_axis_boundary, act_enum, - mesh, arg_infos, result_infos): + def partition(out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos): del result_infos x_spec = get_padded_spec(arg_infos[1]) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) @@ -954,7 +1135,8 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive): scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - act_enum=act_enum) + act_enum=act_enum, + ) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) return local_out, local_t_out, global_updated_amax @@ -965,11 +1147,15 @@ register_primitive(DgatedActLuCastTransposePrimitive) def dgated_act_lu_cast_transpose( - dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, - scale_inv: jnp.ndarray, out_dtype: TEDType, + dz: jnp.ndarray, + x: jnp.ndarray, + amax: jnp.ndarray, + scale: jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: TEDType, static_axis_boundary: int, - activation_type: Sequence[Union[str, Callable]] - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + activation_type: Sequence[Union[str, Callable]], +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ cast transpose d_gated_act_lu fusion wrapper Return FP8(dgated_act_lu(inputs)) @@ -983,4 +1169,5 @@ def dgated_act_lu_cast_transpose( scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - act_enum=act_type_id) + act_enum=act_type_id, + ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index a08f67a0384722ae00586342daa7295a8411c228..6e74f816af4edeaaf303f94678e95982598aed51 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -7,64 +7,59 @@ #ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ #define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include + #include #include #include -#include +#include #include #include -#include - -#include -#include -#include +#include -#include -#include -#include #include "common/common.h" #include "common/util/logging.h" #include "utils.h" -#include -#include -#include -#include - namespace transformer_engine { namespace jax { constexpr int kMaxNumDim = 8; - // TODO: Rename Shape to ??? struct Shape { - int num_dim; - size_t dims[kMaxNumDim]; + int num_dim; + size_t dims[kMaxNumDim]; - void from_vector(const std::vector &shape); + void from_vector(const std::vector &shape); - std::vector to_vector() const; + std::vector to_vector() const; }; // Phuong: These 3 functions need to stay in the header file for compilation purpose // 1. 
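Note (not part of the patch): the descriptors declared below travel from Python to each custom call as an opaque byte string that the C++ side reinterprets (PackOpaque/UnpackOpaque). A hypothetical Python-side packing sketch follows; the field layout is invented for illustration and is not the real CustomCallCommonDescriptor binary layout, which is fixed by the C++ struct and the pybind bindings.

# Hypothetical sketch of the opaque-descriptor pattern: pack a small POD-like
# record into bytes on the Python side, unpack it by reinterpretation in C++.
import struct

def pack_common_descriptor_sketch(shape, in_dtype: int, out_dtype: int, act_enum: int = 0) -> bytes:
    num_dim = len(shape)
    dims = list(shape) + [0] * (8 - num_dim)   # kMaxNumDim == 8, zero-padded
    return struct.pack("=i8Qiii", num_dim, *dims, in_dtype, out_dtype, act_enum)

opaque = pack_common_descriptor_sketch((1024, 4096), in_dtype=3, out_dtype=5)
assert isinstance(opaque, bytes)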
-inline bool use_fp8(DType type) { - return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; -} +inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // 2. template pybind11::bytes PackOpaque(const T &descriptor) { - auto str = std::string(reinterpret_cast(&descriptor), sizeof(T)); - return pybind11::bytes(str); + auto str = std::string(reinterpret_cast(&descriptor), sizeof(T)); + return pybind11::bytes(str); } // 3. template const T *UnpackOpaque(const char *opaque, size_t opaque_len) { - if (opaque_len != sizeof(T)) { - throw std::runtime_error("Invalid opaque object size"); - } - return reinterpret_cast(opaque); + if (opaque_len != sizeof(T)) { + throw std::runtime_error("Invalid opaque object size"); + } + return reinterpret_cast(opaque); } std::vector MakeShapeVector(NVTEShape shape); @@ -72,45 +67,45 @@ std::vector MakeShapeVector(NVTEShape shape); // Packing struct CustomCallCommonDescriptor { - Shape shape; - DType in_dtype; - DType out_dtype; - size_t act_enum; + Shape shape; + DType in_dtype; + DType out_dtype; + size_t act_enum; }; pybind11::bytes PackCustomCallCommonDescriptor(const std::vector &shape, DType in_dtype, DType out_dtype, size_t act_enum = 0); struct CustomCallCommonWkDescriptor { - Shape shape; - Shape wkshape; - DType in_dtype; - DType out_dtype; - DType wk_dtype; - size_t act_enum; + Shape shape; + Shape wkshape; + DType in_dtype; + DType out_dtype; + DType wk_dtype; + size_t act_enum; }; pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector &shape, - const std::vector &wkshape, - DType in_dtype, DType out_dtype, DType wk_dtype, + const std::vector &wkshape, DType in_dtype, + DType out_dtype, DType wk_dtype, size_t act_enum = 0); struct CustomCallNormDescriptor { - size_t batch_size; - size_t hidden_size; - size_t wkspace_size; - size_t barrier_size; - Shape dgamma_part_shape; - Shape dbeta_part_shape; - DType x_dtype; - DType w_dtype; - DType wkspace_dtype; - DType barrier_dtype; - DType dgamma_part_dtype; - DType dbeta_part_dtype; - bool zero_centered_gamma; - float eps; - int sm_margin; + size_t batch_size; + size_t hidden_size; + size_t wkspace_size; + size_t barrier_size; + Shape dgamma_part_shape; + Shape dbeta_part_shape; + DType x_dtype; + DType w_dtype; + DType wkspace_dtype; + DType barrier_dtype; + DType dgamma_part_dtype; + DType dbeta_part_dtype; + bool zero_centered_gamma; + float eps; + int sm_margin; }; pybind11::bytes PackCustomCallNormDescriptor( @@ -120,13 +115,13 @@ pybind11::bytes PackCustomCallNormDescriptor( DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin); struct SoftmaxDescriptor { - size_t batch_size; - size_t padding_size; - size_t head_dim; - size_t q_seqlen; - size_t k_seqlen; - DType dtype; - float scale_factor; + size_t batch_size; + size_t padding_size; + size_t head_dim; + size_t q_seqlen; + size_t k_seqlen; + DType dtype; + float scale_factor; }; pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size, @@ -134,23 +129,23 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin DType dtype, float scale_factor); struct CustomCallFusedAttnDescriptor { - size_t input_batch; - size_t bias_batch; - size_t q_max_seqlen; - size_t kv_max_seqlen; - size_t attn_heads; - size_t num_gqa_groups; - size_t bias_heads; - size_t head_dim; - size_t wkspace_size; - float scaling_factor; - float dropout_probability; - NVTE_Bias_Type bias_type; - NVTE_Mask_Type mask_type; - NVTE_QKV_Layout 
qkv_layout; - DType dtype; - DType wkspace_dtype; - bool is_training; + size_t input_batch; + size_t bias_batch; + size_t q_max_seqlen; + size_t kv_max_seqlen; + size_t attn_heads; + size_t num_gqa_groups; + size_t bias_heads; + size_t head_dim; + size_t wkspace_size; + float scaling_factor; + float dropout_probability; + NVTE_Bias_Type bias_type; + NVTE_Mask_Type mask_type; + NVTE_QKV_Layout qkv_layout; + DType dtype; + DType wkspace_dtype; + bool is_training; }; pybind11::bytes PackCustomCallFusedAttnDescriptor( @@ -167,10 +162,9 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype); + DType in_dtype, DType out_dtype); -void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); +void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); // Activation @@ -183,13 +177,13 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype); + DType in_dtype, DType out_dtype); void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); + size_t opaque_len); void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); + size_t opaque_len); // Normalization diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index e5c274571d1712d09f889f3712cc6a9f8f82f7a9..84e3ef2e8957c44a79a86d8a7c43bbef5603725d 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -4,8 +4,9 @@ * See LICENSE for license information. 
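Note (not part of the patch): get_activation_len below distinguishes plain activations from gated ones. Gated variants (GEGLU, SWIGLU, REGLU, ...) read an input whose second-to-last dimension is 2, one slice gated through the nonlinearity and one passed through linearly, which is why they report length 2 while GELU/SiLU/ReLU report 1. A plain-JAX illustration of the gated case; which slice is the gate is an assumption here, not taken from the kernel.

# Illustration only: activation "length" and a reference gated GELU.
import jax
import jax.numpy as jnp

ACTIVATION_LEN = {
    "gelu": 1, "silu": 1, "relu": 1, "qgelu": 1, "srelu": 1,
    "geglu": 2, "swiglu": 2, "reglu": 2, "qgeglu": 2, "sreglu": 2,
}

def reference_geglu(x: jnp.ndarray) -> jnp.ndarray:
    """x: (..., 2, hidden); first slice assumed to be the gated branch."""
    gate, linear = x[..., 0, :], x[..., 1, :]
    return jax.nn.gelu(gate) * linear

x = jnp.ones((8, 2, 256), dtype=jnp.bfloat16)
assert reference_geglu(x).shape == (8, 256)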
************************************************************************/ -#include "jax/csrc/extensions.h" #include "transformer_engine/activation.h" + +#include "jax/csrc/extensions.h" #include "transformer_engine/transpose.h" namespace transformer_engine { @@ -13,332 +14,330 @@ namespace jax { size_t get_activation_len(NVTE_Activation_Type activation_enum) { switch (activation_enum) { - case NVTE_Activation_Type::GELU: return 1; - case NVTE_Activation_Type::GEGLU: return 2; - case NVTE_Activation_Type::SILU: return 1; - case NVTE_Activation_Type::SWIGLU: return 2; - case NVTE_Activation_Type::RELU: return 1; - case NVTE_Activation_Type::REGLU: return 2; - case NVTE_Activation_Type::QGELU: return 1; - case NVTE_Activation_Type::QGEGLU: return 2; - case NVTE_Activation_Type::SRELU: return 1; - case NVTE_Activation_Type::SREGLU: return 2; + case NVTE_Activation_Type::GELU: + return 1; + case NVTE_Activation_Type::GEGLU: + return 2; + case NVTE_Activation_Type::SILU: + return 1; + case NVTE_Activation_Type::SWIGLU: + return 2; + case NVTE_Activation_Type::RELU: + return 1; + case NVTE_Activation_Type::REGLU: + return 2; + case NVTE_Activation_Type::QGELU: + return 1; + case NVTE_Activation_Type::QGEGLU: + return 2; + case NVTE_Activation_Type::SRELU: + return 1; + case NVTE_Activation_Type::SREGLU: + return 2; default: NVTE_ERROR("Unsupported ActivationEnum"); break; - return -1; + return -1; } } void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, - cudaStream_t stream, float *scale_inverse, float *amax, void *output, - NVTE_Activation_Type act_enum) { - auto act_len = get_activation_len(act_enum); - auto input_shape = std::vector{m, n * act_len}; - auto output_shape = std::vector{m, n}; - auto input_tensor = TensorWrapper(input, input_shape, - static_cast(in_dtype)); - auto output_tensor = TensorWrapper(output, output_shape, - static_cast(out_dtype), amax, - scale, scale_inverse); - switch (act_enum) { + cudaStream_t stream, float *scale_inverse, float *amax, void *output, + NVTE_Activation_Type act_enum) { + auto act_len = get_activation_len(act_enum); + auto input_shape = std::vector{m, n * act_len}; + auto output_shape = std::vector{m, n}; + auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); + auto output_tensor = TensorWrapper(output, output_shape, static_cast(out_dtype), amax, + scale, scale_inverse); + switch (act_enum) { case NVTE_Activation_Type::GELU: - nvte_gelu(input_tensor.data(), output_tensor.data(), stream); - break; + nvte_gelu(input_tensor.data(), output_tensor.data(), stream); + break; case NVTE_Activation_Type::GEGLU: - nvte_geglu(input_tensor.data(), output_tensor.data(), stream); - break; + nvte_geglu(input_tensor.data(), output_tensor.data(), stream); + break; case NVTE_Activation_Type::SILU: - nvte_silu(input_tensor.data(), output_tensor.data(), stream); - break; + nvte_silu(input_tensor.data(), output_tensor.data(), stream); + break; case NVTE_Activation_Type::SWIGLU: - nvte_swiglu(input_tensor.data(), output_tensor.data(), stream); - break; - case NVTE_Activation_Type::RELU: - nvte_relu(input_tensor.data(), output_tensor.data(), stream); - break; - case NVTE_Activation_Type::REGLU: - nvte_reglu(input_tensor.data(), output_tensor.data(), stream); - break; - case NVTE_Activation_Type::QGELU: - nvte_qgelu(input_tensor.data(), output_tensor.data(), stream); - break; - case NVTE_Activation_Type::QGEGLU: - nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream); - break; - case 
NVTE_Activation_Type::SRELU: - nvte_srelu(input_tensor.data(), output_tensor.data(), stream); - break; - case NVTE_Activation_Type::SREGLU: - nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); - break; - default: - NVTE_ERROR("Unsupported ActivationEnum"); - break; - } + nvte_swiglu(input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::RELU: + nvte_relu(input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::REGLU: + nvte_reglu(input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::QGELU: + nvte_qgelu(input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::QGEGLU: + nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SRELU: + nvte_srelu(input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SREGLU: + nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); + break; + default: + NVTE_ERROR("Unsupported ActivationEnum"); + break; + } } void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *output = buffers[1]; + auto *input = buffers[0]; + auto *output = buffers[1]; - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - auto act_enum = static_cast(desc.act_enum);; + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto act_enum = static_cast(desc.act_enum); + ; - ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, - nullptr, nullptr, output, act_enum); + ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output, + act_enum); } void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - float *amax = reinterpret_cast(buffers[1]); - float *scale = reinterpret_cast(buffers[2]); - float *scale_inv = reinterpret_cast(buffers[3]); - auto *output = buffers[4]; - float *amax_out = reinterpret_cast(buffers[5]); - assert(amax == amax_out); + auto *input = buffers[0]; + float *amax = reinterpret_cast(buffers[1]); + float *scale = reinterpret_cast(buffers[2]); + float *scale_inv = reinterpret_cast(buffers[3]); + auto *output = buffers[4]; + float *amax_out = reinterpret_cast(buffers[5]); + assert(amax == amax_out); - const auto &desc = *UnpackOpaque(opaque, opaque_len); - if (!use_fp8(desc.out_dtype)) { - scale = nullptr; - scale_inv = nullptr; - amax_out = nullptr; - } - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - auto act_enum = static_cast(desc.act_enum);; + const auto &desc = *UnpackOpaque(opaque, opaque_len); + if (!use_fp8(desc.out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto act_enum = static_cast(desc.act_enum); + ; - ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, - scale_inv, amax_out, output, act_enum); + ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, output, + act_enum); } void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *act_input = buffers[1]; - auto *output = buffers[2]; + auto *input = buffers[0]; + auto *act_input = buffers[1]; + auto *output = buffers[2]; - const auto &desc 
= *UnpackOpaque(opaque, opaque_len); - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - auto act_enum = static_cast(desc.act_enum);; + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto act_enum = static_cast(desc.act_enum); + ; - auto act_len = get_activation_len(act_enum); - auto input_shape = std::vector{m, n}; - auto act_input_shape = std::vector{m, n * act_len}; - auto output_shape = std::vector{m, n * act_len}; + auto act_len = get_activation_len(act_enum); + auto input_shape = std::vector{m, n}; + auto act_input_shape = std::vector{m, n * act_len}; + auto output_shape = std::vector{m, n * act_len}; - auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); - auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype); + auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); + auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); + auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype); - switch (act_enum) { - case NVTE_Activation_Type::GELU: - nvte_dgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), stream); - break; - case NVTE_Activation_Type::GEGLU: - nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), stream); - break; - case NVTE_Activation_Type::SILU: - nvte_dsilu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), stream); - break; - case NVTE_Activation_Type::SWIGLU: - nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), stream); - break; - case NVTE_Activation_Type::RELU: - nvte_drelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), stream); - break; - case NVTE_Activation_Type::REGLU: - nvte_dreglu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), stream); - break; - case NVTE_Activation_Type::QGELU: - nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), stream); - break; - case NVTE_Activation_Type::QGEGLU: - nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), stream); - break; - case NVTE_Activation_Type::SRELU: - nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), stream); - break; - case NVTE_Activation_Type::SREGLU: - nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), stream); - break; - default: - NVTE_ERROR("Unsupported ActivationEnum"); - break; - } + switch (act_enum) { + case NVTE_Activation_Type::GELU: + nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::GEGLU: + nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SILU: + nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SWIGLU: + nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::RELU: + nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::REGLU: + nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::QGELU: + nvte_dqgelu(input_tensor.data(), 
act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::QGEGLU: + nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SRELU: + nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SREGLU: + nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + default: + NVTE_ERROR("Unsupported ActivationEnum"); + break; + } } pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype) { - auto input_shape = std::vector{batch_size, hidden_size}; - auto dact_input_shape = std::vector{batch_size, hidden_size}; - auto output_shape = std::vector{batch_size, hidden_size}; - auto output_trans_shape = std::vector{hidden_size, batch_size}; - auto dbias_shape = std::vector{hidden_size}; + DType in_dtype, DType out_dtype) { + auto input_shape = std::vector{batch_size, hidden_size}; + auto dact_input_shape = std::vector{batch_size, hidden_size}; + auto output_shape = std::vector{batch_size, hidden_size}; + auto output_trans_shape = std::vector{hidden_size, batch_size}; + auto dbias_shape = std::vector{hidden_size}; - auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype); - auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); - auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); - auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); + auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype); + auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); + auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); + auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); - TensorWrapper dummy_workspace; + TensorWrapper dummy_workspace; - // For now, all dbias_dact(-s) have the same workspace size - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), dummy_workspace.data(), nullptr); + // For now, all dbias_dact(-s) have the same workspace size + nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), dummy_workspace.data(), nullptr); - auto work_shape = MakeShapeVector(dummy_workspace.shape()); - return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); + auto work_shape = MakeShapeVector(dummy_workspace.shape()); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); } void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - auto *input = buffers[0]; - auto *act_input = buffers[1]; - float *amax = reinterpret_cast(buffers[2]); - float *scale = reinterpret_cast(buffers[3]); - float *scale_inv = reinterpret_cast(buffers[4]); - auto *output = buffers[5]; - auto *output_trans = buffers[6]; - auto *dbias = buffers[7]; - float *amax_out = reinterpret_cast(buffers[8]); - void *workspace_ptr = buffers[9]; + size_t opaque_len) { + auto *input = buffers[0]; + auto *act_input = buffers[1]; + float *amax = reinterpret_cast(buffers[2]); + float *scale 
= reinterpret_cast(buffers[3]); + float *scale_inv = reinterpret_cast(buffers[4]); + auto *output = buffers[5]; + auto *output_trans = buffers[6]; + auto *dbias = buffers[7]; + float *amax_out = reinterpret_cast(buffers[8]); + void *workspace_ptr = buffers[9]; - const auto &desc = *UnpackOpaque(opaque, opaque_len); - assert(amax == amax_out); - if (!use_fp8(desc.out_dtype)) { - scale = nullptr; - scale_inv = nullptr; - amax_out = nullptr; - } - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - auto act_enum = static_cast(desc.act_enum);; - auto input_shape = std::vector{m, n}; - auto act_input_shape = std::vector{m, n}; - auto output_shape = std::vector{m, n}; - auto output_trans_shape = std::vector{n, m}; - auto dbias_shape = std::vector{n}; + const auto &desc = *UnpackOpaque(opaque, opaque_len); + assert(amax == amax_out); + if (!use_fp8(desc.out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto act_enum = static_cast(desc.act_enum); + ; + auto input_shape = std::vector{m, n}; + auto act_input_shape = std::vector{m, n}; + auto output_shape = std::vector{m, n}; + auto output_trans_shape = std::vector{n, m}; + auto dbias_shape = std::vector{n}; - auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); - auto output_tensor = - TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); + auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); + auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); + auto output_tensor = + TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto output_trans_tensor = + TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); - auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); + auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); - switch (act_enum) { - case NVTE_Activation_Type::GELU: - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); - break; - case NVTE_Activation_Type::SILU: - nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); - break; - case NVTE_Activation_Type::RELU: - nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); - break; - case NVTE_Activation_Type::QGELU: - nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); - break; - case NVTE_Activation_Type::SRELU: - nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); - break; - default: - 
NVTE_ERROR("Unsupported ActivationEnum"); - break; - } + switch (act_enum) { + case NVTE_Activation_Type::GELU: + nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); + break; + case NVTE_Activation_Type::SILU: + nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); + break; + case NVTE_Activation_Type::RELU: + nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); + break; + case NVTE_Activation_Type::QGELU: + nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); + break; + case NVTE_Activation_Type::SRELU: + nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); + break; + default: + NVTE_ERROR("Unsupported ActivationEnum"); + break; + } } void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - auto *input = buffers[0]; - auto *act_input = buffers[1]; - float *amax = reinterpret_cast(buffers[2]); - float *scale = reinterpret_cast(buffers[3]); - float *scale_inv = reinterpret_cast(buffers[4]); - auto *output = buffers[5]; - auto *output_trans = buffers[6]; - float *amax_out = reinterpret_cast(buffers[7]); + size_t opaque_len) { + auto *input = buffers[0]; + auto *act_input = buffers[1]; + float *amax = reinterpret_cast(buffers[2]); + float *scale = reinterpret_cast(buffers[3]); + float *scale_inv = reinterpret_cast(buffers[4]); + auto *output = buffers[5]; + auto *output_trans = buffers[6]; + float *amax_out = reinterpret_cast(buffers[7]); - const auto &desc = *UnpackOpaque(opaque, opaque_len); - assert(amax == amax_out); - if (!use_fp8(desc.out_dtype)) { - scale = nullptr; - scale_inv = nullptr; - amax_out = nullptr; - } - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - auto act_enum = static_cast(desc.act_enum);; - auto input_shape = desc.shape.to_vector(); - auto act_input_shape = std::vector{m, n * 2}; - auto output_shape = std::vector{m, n * 2}; - auto output_trans_shape = std::vector{n * 2, m}; + const auto &desc = *UnpackOpaque(opaque, opaque_len); + assert(amax == amax_out); + if (!use_fp8(desc.out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto act_enum = static_cast(desc.act_enum); + ; + auto input_shape = desc.shape.to_vector(); + auto act_input_shape = std::vector{m, n * 2}; + auto output_shape = std::vector{m, n * 2}; + auto output_trans_shape = std::vector{n * 2, m}; - auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); - auto output_tensor = - TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); + auto act_input_tensor = TensorWrapper(act_input, act_input_shape, 
desc.in_dtype); + auto output_tensor = + TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto output_trans_tensor = + TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); - switch (act_enum) { - case NVTE_Activation_Type::GEGLU: - nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - stream); - break; - case NVTE_Activation_Type::SWIGLU: - nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - stream); - break; - case NVTE_Activation_Type::REGLU: - nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - stream); - break; - case NVTE_Activation_Type::QGEGLU: - nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - stream); - break; - case NVTE_Activation_Type::SREGLU: - nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - stream); - break; - default: - NVTE_ERROR("Unsupported ActivationEnum"); - break; - } + switch (act_enum) { + case NVTE_Activation_Type::GEGLU: + nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + output_trans_tensor.data(), stream); + break; + case NVTE_Activation_Type::SWIGLU: + nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), stream); + break; + case NVTE_Activation_Type::REGLU: + nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + output_trans_tensor.data(), stream); + break; + case NVTE_Activation_Type::QGEGLU: + nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), stream); + break; + case NVTE_Activation_Type::SREGLU: + nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), stream); + break; + default: + NVTE_ERROR("Unsupported ActivationEnum"); + break; + } } } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 58c9915b4ceb6a7300aa96a60c83bf0e3d1c8dd0..9f332d9a29fe58f10ffbc1ec8d424b8ed9ac4654 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -16,11 +16,11 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim) { - auto backend = nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, - head_dim); - return backend; + auto backend = nvte_get_fused_attn_backend( + static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, + mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, + head_dim); + return backend; } /* @@ -34,43 +34,43 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend, void *softmax_buf, void *rng_state_buf = nullptr, void *bias_buf = nullptr) { - auto input_batch = 
desc->input_batch; - auto bias_batch = desc->bias_batch; - auto attn_heads = desc->attn_heads; - auto bias_heads = desc->bias_heads; - auto q_max_seqlen = desc->q_max_seqlen; - auto kv_max_seqlen = desc->kv_max_seqlen; - - // all backends need softmax but expect different shapes/dtypes - // start with the max512 sequence length softmax shape/dtype and correct later - tensor_pack->size = 1; - Tensor *softmax_aux = reinterpret_cast(tensor_pack->tensors[0]); - softmax_aux->data.dptr = softmax_buf; - softmax_aux->data.shape = - std::vector{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen}; - softmax_aux->data.dtype = desc->dtype; - - // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax - if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { - tensor_pack->size = 2; - Tensor *rng_state_aux = reinterpret_cast(tensor_pack->tensors[1]); - rng_state_aux->data.dptr = rng_state_buf; - rng_state_aux->data.shape = std::vector{2}; - rng_state_aux->data.dtype = DType::kInt64; - // correct softmax shape/dtype - softmax_aux->data.shape.at(3) = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1} - softmax_aux->data.dtype = DType::kFloat32; - - // include bias if enabled - if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) { - tensor_pack->size = 3; - Tensor *bias_aux = reinterpret_cast(tensor_pack->tensors[2]); - bias_aux->data.dptr = bias_buf; - bias_aux->data.shape = - std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - bias_aux->data.dtype = desc->dtype; - } + auto input_batch = desc->input_batch; + auto bias_batch = desc->bias_batch; + auto attn_heads = desc->attn_heads; + auto bias_heads = desc->bias_heads; + auto q_max_seqlen = desc->q_max_seqlen; + auto kv_max_seqlen = desc->kv_max_seqlen; + + // all backends need softmax but expect different shapes/dtypes + // start with the max512 sequence length softmax shape/dtype and correct later + tensor_pack->size = 1; + Tensor *softmax_aux = reinterpret_cast(tensor_pack->tensors[0]); + softmax_aux->data.dptr = softmax_buf; + softmax_aux->data.shape = + std::vector{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen}; + softmax_aux->data.dtype = desc->dtype; + + // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax + if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + tensor_pack->size = 2; + Tensor *rng_state_aux = reinterpret_cast(tensor_pack->tensors[1]); + rng_state_aux->data.dptr = rng_state_buf; + rng_state_aux->data.shape = std::vector{2}; + rng_state_aux->data.dtype = DType::kInt64; + // correct softmax shape/dtype + softmax_aux->data.shape.at(3) = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1} + softmax_aux->data.dtype = DType::kFloat32; + + // include bias if enabled + if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) { + tensor_pack->size = 3; + Tensor *bias_aux = reinterpret_cast(tensor_pack->tensors[2]); + bias_aux->data.dptr = bias_buf; + bias_aux->data.shape = + std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; + bias_aux->data.dtype = desc->dtype; } + } } /* @@ -85,19 +85,19 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const CustomCallFusedAttnDescriptor *desc, NVTE_Fused_Attn_Backend backend, void *softmax_buf, void *rng_state_buf, void *bias_buf) { - // Backward calls put everything into the tensor pack for every backend - // so we set dummy bias_type and backend choices here to follow the correct code path - auto 
dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; - auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; - PrepareFusedAttnForwardAuxTensors(tensor_pack, desc, dummy_bias_type, dummy_backend, - softmax_buf, rng_state_buf, bias_buf); - - // correct softmax shape for max512 sequence length kernel - if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { - Tensor *softmax_aux = reinterpret_cast(tensor_pack->tensors[0]); - softmax_aux->data.shape.at(3) = desc->kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks} - softmax_aux->data.dtype = desc->dtype; - } + // Backward calls put everything into the tensor pack for every backend + // so we set dummy bias_type and backend choices here to follow the correct code path + auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; + auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + PrepareFusedAttnForwardAuxTensors(tensor_pack, desc, dummy_bias_type, dummy_backend, softmax_buf, + rng_state_buf, bias_buf); + + // correct softmax shape for max512 sequence length kernel + if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + Tensor *softmax_aux = reinterpret_cast(tensor_pack->tensors[0]); + softmax_aux->data.shape.at(3) = desc->kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks} + softmax_aux->data.dtype = desc->dtype; + } } pybind11::tuple GetFusedAttnForwardWorkspaceSizes( @@ -105,77 +105,75 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) { - // For qkv_packed - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; - auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - - // For kv_packed - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - - // For separate q, k, v - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); - auto v_shape = k_shape; - auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); - - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - - // F16 doesn't use this tensor - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); - auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); - - auto q_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - - auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); - - NVTETensorPack aux_output_tensors; - nvte_tensor_pack_create(&aux_output_tensors); - - auto dummy_ragged_offset_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - TensorWrapper query_workspace_tensor; - if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { - assert(q_max_seqlen == kv_max_seqlen); - nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), 
dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), - q_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), - nullptr); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { - nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, - kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, query_workspace_tensor.data(), nullptr); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { - nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, - query_workspace_tensor.data(), nullptr); - } else { - NVTE_ERROR("Unsupported QKVLayout."); - } - - auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); + // For qkv_packed + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); + + // For kv_packed + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + + // For separate q, k, v + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto v_shape = k_shape; + auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); + + auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; + auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); + + // F16 doesn't use this tensor + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); + + auto q_cu_seqlens_tensor = + TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); + auto kv_cu_seqlens_tensor = + TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); + + auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); + + NVTETensorPack aux_output_tensors; + nvte_tensor_pack_create(&aux_output_tensors); + + auto dummy_ragged_offset_tensor = + TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); + TensorWrapper query_workspace_tensor; + if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { + assert(q_max_seqlen == kv_max_seqlen); + nvte_fused_attn_fwd_qkvpacked( + qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, 
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { + nvte_fused_attn_fwd_kvpacked( + q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), + nullptr); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + } else { + NVTE_ERROR("Unsupported QKVLayout."); + } + + auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape()); + return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); } pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( @@ -183,214 +181,207 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) { - auto output_shape = std::vector{batch_size * q_max_seqlen, attn_heads, head_dim}; - auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); - auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); - - auto bias_shape = std::vector{1, attn_heads, q_max_seqlen, kv_max_seqlen}; - auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - - // F16 doesn't use s_tensor - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); - - auto q_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); - - NVTETensorPack aux_input_tensors; - nvte_tensor_pack_create(&aux_input_tensors); - - TensorWrapper query_workspace_tensor; - - auto dummy_ragged_offset_tensor = - TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); - if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { - assert(q_max_seqlen == kv_max_seqlen); - auto qkv_shape = std::vector{batch_size * q_max_seqlen, 3, attn_heads, head_dim}; - auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - nvte_fused_attn_bwd_qkvpacked( - qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), 
q_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), nullptr); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { - auto q_shape = std::vector{batch_size * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto kv_shape = - std::vector{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), - nullptr); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { - auto q_shape = std::vector{batch_size * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto k_shape = std::vector{batch_size * kv_max_seqlen, num_gqa_groups, head_dim}; - auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); - auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); - auto v_shape = k_shape; - auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); - auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype); - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, - dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), - nullptr); - } else { - NVTE_ERROR("Unsupported QKVLayout."); - } + auto output_shape = std::vector{batch_size * q_max_seqlen, attn_heads, head_dim}; + auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); + auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); + + auto bias_shape = std::vector{1, attn_heads, q_max_seqlen, kv_max_seqlen}; + auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); + + // F16 doesn't use s_tensor + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + + auto q_cu_seqlens_tensor = + TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); + auto kv_cu_seqlens_tensor = + TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); + + NVTETensorPack aux_input_tensors; + nvte_tensor_pack_create(&aux_input_tensors); + + TensorWrapper query_workspace_tensor; - auto workspace_shape = 
MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); + auto dummy_ragged_offset_tensor = + TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); + if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { + assert(q_max_seqlen == kv_max_seqlen); + auto qkv_shape = std::vector{batch_size * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); + auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); + nvte_fused_attn_bwd_qkvpacked( + qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { + auto q_shape = std::vector{batch_size * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto kv_shape = std::vector{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + nvte_fused_attn_bwd_kvpacked( + q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), + nullptr); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + auto q_shape = std::vector{batch_size * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto k_shape = std::vector{batch_size * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto v_shape = k_shape; + auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); + auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype); + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + } else { + NVTE_ERROR("Unsupported QKVLayout."); + } + + auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape()); + 
return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); } void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - /* Input buffers from XLA */ - /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ - void *bias = buffers[3]; - void *q_cu_seqlens = buffers[4]; - void *kv_cu_seqlens = buffers[5]; - void *seed = buffers[6]; - - /* Output buffer from XLA */ - void *output = buffers[7]; - void *softmax_aux = buffers[8]; - void *rng_state = buffers[9]; - void *workspace = buffers[10]; - - /* Descriptor */ - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen = descriptor.kv_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - auto qkv_layout = descriptor.qkv_layout; - auto dtype = descriptor.dtype; - - /* Input tensors */ + const CustomCallFusedAttnDescriptor &descriptor = + *UnpackOpaque(opaque, opaque_len); + + /* Input buffers from XLA */ + /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ + void *bias = buffers[3]; + void *q_cu_seqlens = buffers[4]; + void *kv_cu_seqlens = buffers[5]; + void *seed = buffers[6]; + + /* Output buffer from XLA */ + void *output = buffers[7]; + void *softmax_aux = buffers[8]; + void *rng_state = buffers[9]; + void *workspace = buffers[10]; + + /* Descriptor */ + auto input_batch = descriptor.input_batch; + auto bias_batch = descriptor.bias_batch; + auto q_max_seqlen = descriptor.q_max_seqlen; + auto kv_max_seqlen = descriptor.kv_max_seqlen; + auto attn_heads = descriptor.attn_heads; + auto num_gqa_groups = descriptor.num_gqa_groups; + auto bias_heads = descriptor.bias_heads; + auto head_dim = descriptor.head_dim; + auto scaling_factor = descriptor.scaling_factor; + auto dropout_probability = descriptor.dropout_probability; + auto bias_type = descriptor.bias_type; + auto mask_type = descriptor.mask_type; + auto qkv_layout = descriptor.qkv_layout; + auto dtype = descriptor.dtype; + + /* Input tensors */ + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto v_shape = k_shape; + auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; + auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); + + /* Output tensors */ + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 + auto o_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto o_tensor = TensorWrapper(output, o_shape, dtype); + auto q_cu_seqlens_tensor = + TensorWrapper(q_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); + auto kv_cu_seqlens_tensor = + TensorWrapper(kv_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); + + /* Prepare RNG state */ + auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + auto backend = + nvte_get_fused_attn_backend(static_cast(dtype), static_cast(dtype), + qkv_layout, bias_type, mask_type, 
dropout_probability, attn_heads, + num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim); + PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); + + /* Auxiliary tensors (to be propagated to the backward pass later) */ + NVTETensorPack aux_output_tensors; + nvte_tensor_pack_create(&aux_output_tensors); + PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, + softmax_aux); + + /* cuDNN workspace */ + auto workspace_tensor = TensorWrapper(workspace, std::vector{descriptor.wkspace_size}, + descriptor.wkspace_dtype); + + auto dummy_ragged_offset_tensor = + TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); + /* Call the underly NVTE API */ + if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { + auto qkv = buffers[0]; + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); + nvte_fused_attn_fwd_qkvpacked( + qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), rng_state_tensor.data(), q_max_seqlen, + descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, workspace_tensor.data(), stream); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { + auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto kv = buffers[1]; + auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); + nvte_fused_attn_fwd_kvpacked( + q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + workspace_tensor.data(), stream); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + auto q = buffers[0]; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k = buffers[1]; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_tensor = TensorWrapper(k, k_shape, dtype); + auto v = buffers[2]; auto v_shape = k_shape; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); - - /* Output tensors */ - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 - auto o_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto o_tensor = TensorWrapper(output, o_shape, dtype); - auto q_cu_seqlens_tensor = - TensorWrapper(q_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(kv_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - - /* Prepare RNG state */ - auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); - auto backend = nvte_get_fused_attn_backend( - 
static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim); - PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); - - /* Auxiliary tensors (to be propagated to the backward pass later) */ - NVTETensorPack aux_output_tensors; - nvte_tensor_pack_create(&aux_output_tensors); - PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, - softmax_aux); - - /* cuDNN workspace */ - auto workspace_tensor = TensorWrapper(workspace, std::vector{descriptor.wkspace_size}, - descriptor.wkspace_dtype); - - auto dummy_ragged_offset_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - /* Call the underly NVTE API */ - if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { - auto qkv = buffers[0]; - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; - auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); - nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, - descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - workspace_tensor.data(), stream); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { - auto q = buffers[0]; - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv = buffers[1]; - auto kv_shape = - std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); - nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, workspace_tensor.data(), stream); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { - auto q = buffers[0]; - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k = buffers[1]; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v = buffers[2]; - auto v_shape = k_shape; - auto v_tensor = TensorWrapper(v, v_shape, dtype); - nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - descriptor.is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, - workspace_tensor.data(), stream); - } else { - NVTE_ERROR("Unsupported qkv_layout."); - } - - 
nvte_tensor_pack_destroy(&aux_output_tensors); + auto v_tensor = TensorWrapper(v, v_shape, dtype); + nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, + descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, workspace_tensor.data(), stream); + } else { + NVTE_ERROR("Unsupported qkv_layout."); + } + + nvte_tensor_pack_destroy(&aux_output_tensors); } pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( @@ -398,189 +389,185 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) { + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto v_shape = k_shape; + auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; + + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); + auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); + auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); + // F16 doesn't use this tensor + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + + auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype); + auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); + + auto q_cu_seqlens_tensor = + TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); + auto kv_cu_seqlens_tensor = + TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); + + NVTETensorPack aux_input_tensors; + nvte_tensor_pack_create(&aux_input_tensors); + + TensorWrapper query_workspace_tensor; + auto dummy_ragged_offset_tensor = + TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, query_workspace_tensor.data(), nullptr); + + auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); + return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); +} + +void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + const CustomCallFusedAttnDescriptor &descriptor = + *UnpackOpaque(opaque, 
opaque_len); + + /* Input buffers from XLA */ + /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ + void *bias = buffers[3]; + void *softmax_aux = buffers[4]; + void *rng_state = buffers[5]; + void *output = buffers[6]; + void *doutput = buffers[7]; + void *q_cu_seqlens = buffers[8]; + void *kv_cu_seqlens = buffers[9]; + + /* Output buffer from XLA */ + /* Buffers[10-12] are dq, dk, dv, which are parsed later for different qkv_layout */ + void *dbias = buffers[13]; + void *workspace = buffers[14]; + + /* Descriptor */ + auto input_batch = descriptor.input_batch; + auto bias_batch = descriptor.bias_batch; + auto q_max_seqlen = descriptor.q_max_seqlen; + auto kv_max_seqlen = descriptor.kv_max_seqlen; + auto attn_heads = descriptor.attn_heads; + auto num_gqa_groups = descriptor.num_gqa_groups; + auto bias_heads = descriptor.bias_heads; + auto head_dim = descriptor.head_dim; + auto scaling_factor = descriptor.scaling_factor; + auto dropout_probability = descriptor.dropout_probability; + auto bias_type = descriptor.bias_type; + auto mask_type = descriptor.mask_type; + auto qkv_layout = descriptor.qkv_layout; + auto dtype = descriptor.dtype; + + /* Input tensors */ + auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; + auto output_tensor = TensorWrapper(output, output_shape, dtype); + auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); + + /* Output tensors */ + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 + auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); + auto q_cu_seqlens_tensor = + TensorWrapper(q_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); + auto kv_cu_seqlens_tensor = + TensorWrapper(kv_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); + + /* Auxiliary tensors (propagated from the forward pass) */ + NVTETensorPack aux_input_tensors; + nvte_tensor_pack_create(&aux_input_tensors); + auto backend = + nvte_get_fused_attn_backend(static_cast(dtype), static_cast(dtype), + qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, + num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim); + PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, + rng_state, bias); + + /* cuDNN workspace */ + auto wkspace_size = std::vector{descriptor.wkspace_size}; + auto wkspace_dtype = descriptor.wkspace_dtype; + auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); + + auto dummy_ragged_offset_tensor = + TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); + /* Call the underly NVTE API */ + if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { + auto qkv = buffers[0]; + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); + auto dqkv = buffers[10]; + auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); + nvte_fused_attn_bwd_qkvpacked( + qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + 
workspace_tensor.data(), stream); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { + auto q = buffers[0]; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto kv = buffers[1]; + auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); + auto dq = buffers[10]; + auto dq_tensor = TensorWrapper(dq, q_shape, dtype); + auto dkv = buffers[11]; + auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype); + nvte_fused_attn_bwd_kvpacked( + q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k = buffers[1]; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_tensor = TensorWrapper(k, k_shape, dtype); + auto v = buffers[2]; auto v_shape = k_shape; - auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); - auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); - auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); - auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); - // F16 doesn't use this tensor - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); - - auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); - auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype); - auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - - auto q_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - - NVTETensorPack aux_input_tensors; - nvte_tensor_pack_create(&aux_input_tensors); - - TensorWrapper query_workspace_tensor; - auto dummy_ragged_offset_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); + auto v_tensor = TensorWrapper(v, v_shape, dtype); + auto dq = buffers[10]; + auto dq_tensor = TensorWrapper(dq, q_shape, dtype); + auto dk = buffers[11]; + auto dk_tensor = TensorWrapper(dk, k_shape, dtype); + auto dv = buffers[12]; + auto dv_tensor = TensorWrapper(dv, v_shape, dtype); nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 - &aux_input_tensors, - dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - 
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), nullptr); - - auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); -} - -void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - /* Input buffers from XLA */ - /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ - void *bias = buffers[3]; - void *softmax_aux = buffers[4]; - void *rng_state = buffers[5]; - void *output = buffers[6]; - void *doutput = buffers[7]; - void *q_cu_seqlens = buffers[8]; - void *kv_cu_seqlens = buffers[9]; - - /* Output buffer from XLA */ - /* Buffers[10-12] are dq, dk, dv, which are parsed later for different qkv_layout */ - void *dbias = buffers[13]; - void *workspace = buffers[14]; - - /* Descriptor */ - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen = descriptor.kv_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - auto qkv_layout = descriptor.qkv_layout; - auto dtype = descriptor.dtype; - - /* Input tensors */ - auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - auto output_tensor = TensorWrapper(output, output_shape, dtype); - auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); - - /* Output tensors */ - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 - auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); - auto q_cu_seqlens_tensor = - TensorWrapper(q_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(kv_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - - /* Auxiliary tensors (propagated from the forward pass) */ - NVTETensorPack aux_input_tensors; - nvte_tensor_pack_create(&aux_input_tensors); - auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim); - PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, - rng_state, bias); - - /* cuDNN workspace */ - auto wkspace_size = std::vector{descriptor.wkspace_size}; - auto wkspace_dtype = descriptor.wkspace_dtype; - auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); - - auto dummy_ragged_offset_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - /* Call the underly NVTE API */ - if (qkv_layout 
== NVTE_QKV_Layout::NVTE_BS3HD) { - auto qkv = buffers[0]; - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; - auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); - auto dqkv = buffers[10]; - auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); - nvte_fused_attn_bwd_qkvpacked( - qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { - auto q = buffers[0]; - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv = buffers[1]; - auto kv_shape = - std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); - auto dq = buffers[10]; - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dkv = buffers[11]; - auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype); - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { - auto q = buffers[0]; - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k = buffers[1]; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v = buffers[2]; - auto v_shape = k_shape; - auto v_tensor = TensorWrapper(v, v_shape, dtype); - auto dq = buffers[10]; - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dk = buffers[11]; - auto dk_tensor = TensorWrapper(dk, k_shape, dtype); - auto dv = buffers[12]; - auto dv_tensor = TensorWrapper(dv, v_shape, dtype); - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), - dv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream); - } else { - NVTE_ERROR("Unsupported qkv_layout."); - } - - nvte_tensor_pack_destroy(&aux_input_tensors); + dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, 
dropout_probability, qkv_layout, bias_type, mask_type, + workspace_tensor.data(), stream); + } else { + NVTE_ERROR("Unsupported qkv_layout."); + } + + nvte_tensor_pack_destroy(&aux_input_tensors); } } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/misc.cpp b/transformer_engine/jax/csrc/extensions/misc.cpp index a24d4b4996e1b04a43249c8b370197b5d2584143..c40e899e622f7fad8d89326529a658ea91ae3881 100644 --- a/transformer_engine/jax/csrc/extensions/misc.cpp +++ b/transformer_engine/jax/csrc/extensions/misc.cpp @@ -10,7 +10,7 @@ namespace transformer_engine { namespace jax { std::vector<size_t> MakeShapeVector(NVTEShape shape) { - return std::vector<size_t>(shape.data, shape.data + shape.ndim); + return std::vector<size_t>(shape.data, shape.data + shape.ndim); } void Shape::from_vector(const std::vector<size_t> &shape) { diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 31c8523ddd228a97f41d1a5c6e0efa5cf0e4077f..8063596f3b3e6e7565cdba30b20051f64d40bcde 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -8,7 +8,6 @@ #include "transformer_engine/layer_norm.h" #include "transformer_engine/rmsnorm.h" - namespace transformer_engine { namespace jax { @@ -16,38 +15,38 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd DType in_dtype, DType w_dtype, DType out_dtype, bool is_layer_norm, bool zero_centered_gamma, float eps) { - auto input_shape = std::vector<size_t>{batch_size, hidden_size}; - auto weight_shape = std::vector<size_t>{hidden_size}; - auto intermediates_shape = std::vector<size_t>{batch_size}; - - // empty tensor wrappers are okay just to get workspace size - auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype); - auto output_tensor = TensorWrapper(nullptr, input_shape, out_dtype); - auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); - - // dummy tensor wrappers that will carry workspace size info later - TensorWrapper dummy_work_tensor, dummy_barrier_tensor; - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - auto layernorm_fwd_func = zero_centered_gamma ?
nvte_layernorm1p_fwd : nvte_layernorm_fwd; - if (is_layer_norm) { - auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); - auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); - - layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, - output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr, - num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); - } else { - NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); - nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), - rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(), - dummy_barrier_tensor.data()); - } - - auto work_shape = MakeShapeVector(dummy_work_tensor.shape()); - auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape()); - return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()), - std::make_pair(barrier_shape, dummy_barrier_tensor.dtype())); + auto input_shape = std::vector{batch_size, hidden_size}; + auto weight_shape = std::vector{hidden_size}; + auto intermediates_shape = std::vector{batch_size}; + + // empty tensor wrappers are okay just to get workspace size + auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype); + auto output_tensor = TensorWrapper(nullptr, input_shape, out_dtype); + auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); + + // dummy tensor wrappers that will carry workspace size info later + TensorWrapper dummy_work_tensor, dummy_barrier_tensor; + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; + if (is_layer_norm) { + auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); + auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); + + layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, + output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr, + num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); + } else { + NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); + nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), + rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(), + dummy_barrier_tensor.data()); + } + + auto work_shape = MakeShapeVector(dummy_work_tensor.shape()); + auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape()); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()), + std::make_pair(barrier_shape, dummy_barrier_tensor.dtype())); } void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspace_size, @@ -56,96 +55,95 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac DType out_dtype, void *workspace, DType work_dtype, void *barrier, DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale, float *scale_inv, cudaStream_t stream) { - auto input_shape = std::vector{batch_size, hidden_size}; - auto weight_shape = std::vector{hidden_size}; - auto intermediates_shape = std::vector{batch_size}; - auto workspace_shape = std::vector{workspace_size}; - auto barrier_shape = std::vector{barrier_size}; - auto is_layer_norm = (bias) ? 
true : false; - - auto input_tensor = TensorWrapper(input, input_shape, in_dtype); - auto gamma_tensor = TensorWrapper(weight, weight_shape, in_dtype); - - // assume output dtype = input dtype - // If we need mixed I/O precision in the future, we need an additional - // parameter for output type - auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv); - auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32); - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - - auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype); - auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); - - if (is_layer_norm) { - auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype); - auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32); - - layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, - output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, - num_sm, workspace_tensor.data(), barrier_tensor.data()); - } else { - NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); - nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), - rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(), - barrier_tensor.data()); - } + auto input_shape = std::vector{batch_size, hidden_size}; + auto weight_shape = std::vector{hidden_size}; + auto intermediates_shape = std::vector{batch_size}; + auto workspace_shape = std::vector{workspace_size}; + auto barrier_shape = std::vector{barrier_size}; + auto is_layer_norm = (bias) ? true : false; + + auto input_tensor = TensorWrapper(input, input_shape, in_dtype); + auto gamma_tensor = TensorWrapper(weight, weight_shape, in_dtype); + + // assume output dtype = input dtype + // If we need mixed I/O precision in the future, we need an additional + // parameter for output type + auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv); + auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32); + + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto layernorm_fwd_func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; + + auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype); + auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); + + if (is_layer_norm) { + auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype); + auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32); + + layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, + output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, num_sm, + workspace_tensor.data(), barrier_tensor.data()); + } else { + NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); + nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), + rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(), + barrier_tensor.data()); + } } pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm, bool zero_centered_gamma, float eps) { - auto input_shape = std::vector{batch_size, hidden_size}; - auto weight_shape = std::vector{hidden_size}; - auto intermediates_shape = std::vector{batch_size}; - auto intermediates_dtype = DType::kFloat32; - - // empty tensor wrappers are okay just to get workspace size - auto dz_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype); - auto x_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto gamma_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); - auto xgrad_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); - - // dummy tensor wrappers that will carry workspace size info later - TensorWrapper dummy_work_tensor, dummy_barrier_tensor; - TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor; - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - auto layernorm_bwd_func = zero_centered_gamma ? 
nvte_layernorm1p_bwd : nvte_layernorm_bwd; - - // initialize dBeta information here -- layernorm will modify but RMSnorm will not - std::vector dbeta_part_shape; - if (is_layer_norm) { - auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype); - auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); - - layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), - rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), - wgrad_tensor.data(), dbeta_tensor.data(), - dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), nullptr, - num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); - - dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape()); - } else { - NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); - nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), - gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), - dummy_dgamma_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(), - dummy_barrier_tensor.data()); - - dbeta_part_shape = std::vector{0, 0}; - } - - auto work_shape = MakeShapeVector(dummy_work_tensor.shape()); - auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape()); - auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape()); - return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()), - std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()), - std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()), - std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype())); + auto input_shape = std::vector{batch_size, hidden_size}; + auto weight_shape = std::vector{hidden_size}; + auto intermediates_shape = std::vector{batch_size}; + auto intermediates_dtype = DType::kFloat32; + + // empty tensor wrappers are okay just to get workspace size + auto dz_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype); + auto x_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto gamma_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); + auto xgrad_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); + + // dummy tensor wrappers that will carry workspace size info later + TensorWrapper dummy_work_tensor, dummy_barrier_tensor; + TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor; + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto layernorm_bwd_func = zero_centered_gamma ? 
nvte_layernorm1p_bwd : nvte_layernorm_bwd; + + // initialize dBeta information here -- layernorm will modify but RMSnorm will not + std::vector dbeta_part_shape; + if (is_layer_norm) { + auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype); + auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); + + layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), + dbeta_tensor.data(), dummy_dgamma_part_tensor.data(), + dummy_dbeta_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(), + dummy_barrier_tensor.data()); + + dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape()); + } else { + NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); + nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), + xgrad_tensor.data(), wgrad_tensor.data(), dummy_dgamma_part_tensor.data(), + nullptr, num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); + + dbeta_part_shape = std::vector{0, 0}; + } + + auto work_shape = MakeShapeVector(dummy_work_tensor.shape()); + auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape()); + auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape()); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()), + std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()), + std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()), + std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype())); } void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size, @@ -156,273 +154,269 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part, DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype, cudaStream_t stream) { - auto input_shape = std::vector{batch_size, hidden_size}; - auto weight_shape = std::vector{hidden_size}; - auto intermediates_shape = std::vector{batch_size}; - auto intermediates_dtype = DType::kFloat32; - auto is_layer_norm = (dbeta) ? true : false; - - // assume input type = output type - auto *grad_output = ograd; - auto x_dtype = in_dtype; - auto dz_tensor = TensorWrapper(grad_output, input_shape, x_dtype); - - auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, intermediates_dtype); - - auto *x = input; - auto x_tensor = TensorWrapper(x, input_shape, x_dtype); - - auto gamma_tensor = TensorWrapper(weight, weight_shape, w_dtype); - auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype); - auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype); - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - auto layernorm_bwd_func = zero_centered_gamma ? 
nvte_layernorm1p_bwd : nvte_layernorm_bwd; - - auto workspace_shape = std::vector{wkspace_size}; - auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); - auto barrier_shape = std::vector{barrier_size}; - auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); - auto dgamma_part_tensor = - TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype); - - if (is_layer_norm) { - auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype); - auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype); - auto dbeta_part_tensor = - TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype); - - layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), - rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), - wgrad_tensor.data(), dbeta_tensor.data(), dgamma_part_tensor.data(), - dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(), - barrier_tensor.data()); - } else { - NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); - nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), - gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), - dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(), - barrier_tensor.data()); - } + auto input_shape = std::vector{batch_size, hidden_size}; + auto weight_shape = std::vector{hidden_size}; + auto intermediates_shape = std::vector{batch_size}; + auto intermediates_dtype = DType::kFloat32; + auto is_layer_norm = (dbeta) ? true : false; + + // assume input type = output type + auto *grad_output = ograd; + auto x_dtype = in_dtype; + auto dz_tensor = TensorWrapper(grad_output, input_shape, x_dtype); + + auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, intermediates_dtype); + + auto *x = input; + auto x_tensor = TensorWrapper(x, input_shape, x_dtype); + + auto gamma_tensor = TensorWrapper(weight, weight_shape, w_dtype); + auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype); + auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype); + + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto layernorm_bwd_func = zero_centered_gamma ? 
nvte_layernorm1p_bwd : nvte_layernorm_bwd; + + auto workspace_shape = std::vector{wkspace_size}; + auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); + auto barrier_shape = std::vector{barrier_size}; + auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); + auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype); + + if (is_layer_norm) { + auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype); + auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype); + auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype); + + layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), + dbeta_tensor.data(), dgamma_part_tensor.data(), dbeta_part_tensor.data(), + stream, num_sm, workspace_tensor.data(), barrier_tensor.data()); + } else { + NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); + nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), + xgrad_tensor.data(), wgrad_tensor.data(), dgamma_part_tensor.data(), stream, + num_sm, workspace_tensor.data(), barrier_tensor.data()); + } } void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *weight = buffers[1]; - auto *bias = buffers[2]; - auto *amax = reinterpret_cast(buffers[3]); - auto *scale = reinterpret_cast(buffers[4]); - auto *scale_inv = reinterpret_cast(buffers[5]); - auto *output = buffers[6]; - auto *mu = buffers[7]; - auto *rsigma = buffers[8]; - auto *amax_out = buffers[9]; - auto *workspace = buffers[10]; - auto *barrier = buffers[11]; - assert(amax_out == amax); - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto batch_size = desc.batch_size; - auto hidden_size = desc.hidden_size; - auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; - auto in_dtype = desc.x_dtype; - auto w_dtype = desc.w_dtype; - auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; - auto eps = desc.eps; - auto zero_centered_gamma = desc.zero_centered_gamma; - auto sm_margin = desc.sm_margin; - - auto out_dtype = DType::kFloat8E4M3; - - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + auto *input = buffers[0]; + auto *weight = buffers[1]; + auto *bias = buffers[2]; + auto *amax = reinterpret_cast(buffers[3]); + auto *scale = reinterpret_cast(buffers[4]); + auto *scale_inv = reinterpret_cast(buffers[5]); + auto *output = buffers[6]; + auto *mu = buffers[7]; + auto *rsigma = buffers[8]; + auto *amax_out = buffers[9]; + auto *workspace = buffers[10]; + auto *barrier = buffers[11]; + assert(amax_out == amax); + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto batch_size = desc.batch_size; + auto hidden_size = desc.hidden_size; + auto wkspace_size = desc.wkspace_size; + auto barrier_size = desc.barrier_size; + auto in_dtype = desc.x_dtype; + auto w_dtype = desc.w_dtype; + auto wkspace_dtype = desc.wkspace_dtype; + auto barrier_dtype = desc.barrier_dtype; + auto eps = desc.eps; + auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; + + 
auto out_dtype = DType::kFloat8E4M3; + + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, + eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, + wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, + stream); } void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *weight = buffers[1]; - auto *bias = buffers[2]; - auto *output = buffers[3]; - auto *mu = buffers[4]; - auto *rsigma = buffers[5]; - auto *workspace = buffers[6]; - auto *barrier = buffers[7]; - - float *amax = nullptr; - float *scale = nullptr; - float *scale_inv = nullptr; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto batch_size = desc.batch_size; - auto hidden_size = desc.hidden_size; - auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; - auto in_dtype = desc.x_dtype; - auto w_dtype = desc.w_dtype; - auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; - auto eps = desc.eps; - auto out_dtype = in_dtype; - auto zero_centered_gamma = desc.zero_centered_gamma; - auto sm_margin = desc.sm_margin; - - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + auto *input = buffers[0]; + auto *weight = buffers[1]; + auto *bias = buffers[2]; + auto *output = buffers[3]; + auto *mu = buffers[4]; + auto *rsigma = buffers[5]; + auto *workspace = buffers[6]; + auto *barrier = buffers[7]; + + float *amax = nullptr; + float *scale = nullptr; + float *scale_inv = nullptr; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto batch_size = desc.batch_size; + auto hidden_size = desc.hidden_size; + auto wkspace_size = desc.wkspace_size; + auto barrier_size = desc.barrier_size; + auto in_dtype = desc.x_dtype; + auto w_dtype = desc.w_dtype; + auto wkspace_dtype = desc.wkspace_dtype; + auto barrier_dtype = desc.barrier_dtype; + auto eps = desc.eps; + auto out_dtype = in_dtype; + auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; + + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, + eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, + wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, + stream); } void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - const auto &desc = *UnpackOpaque(opaque, opaque_len); - - auto batch_size = desc.batch_size; - auto hidden_size = desc.hidden_size; - auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; - auto dgamma_part_shape = desc.dgamma_part_shape; - auto dbeta_part_shape = desc.dbeta_part_shape; - auto in_dtype = desc.x_dtype; - auto w_dtype = desc.w_dtype; - auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; - auto dgamma_part_dtype = desc.dgamma_part_dtype; - auto dbeta_part_dtype = desc.dbeta_part_dtype; - auto eps = desc.eps; - auto zero_centered_gamma = desc.zero_centered_gamma; - auto sm_margin = desc.sm_margin; - - auto *ograd = buffers[0]; - auto *mu = buffers[1]; - auto *rsigma = buffers[2]; - auto *input = buffers[3]; - auto *weight = buffers[4]; - auto *xgrad = buffers[5]; - auto *wgrad = buffers[6]; - auto 
*dbeta = buffers[7]; - auto *workspace = buffers[8]; - auto *barrier = buffers[9]; - auto *dgamma_part = buffers[10]; - auto *dbeta_part = buffers[11]; - - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, - dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, - w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, - rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, stream); + const auto &desc = *UnpackOpaque(opaque, opaque_len); + + auto batch_size = desc.batch_size; + auto hidden_size = desc.hidden_size; + auto wkspace_size = desc.wkspace_size; + auto barrier_size = desc.barrier_size; + auto dgamma_part_shape = desc.dgamma_part_shape; + auto dbeta_part_shape = desc.dbeta_part_shape; + auto in_dtype = desc.x_dtype; + auto w_dtype = desc.w_dtype; + auto wkspace_dtype = desc.wkspace_dtype; + auto barrier_dtype = desc.barrier_dtype; + auto dgamma_part_dtype = desc.dgamma_part_dtype; + auto dbeta_part_dtype = desc.dbeta_part_dtype; + auto eps = desc.eps; + auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; + + auto *ograd = buffers[0]; + auto *mu = buffers[1]; + auto *rsigma = buffers[2]; + auto *input = buffers[3]; + auto *weight = buffers[4]; + auto *xgrad = buffers[5]; + auto *wgrad = buffers[6]; + auto *dbeta = buffers[7]; + auto *workspace = buffers[8]; + auto *barrier = buffers[9]; + auto *dgamma_part = buffers[10]; + auto *dbeta_part = buffers[11]; + + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, + dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, + w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, + rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, + dbeta_part_dtype, stream); } void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *weight = buffers[1]; - auto *amax = reinterpret_cast(buffers[2]); - auto *scale = reinterpret_cast(buffers[3]); - auto *scale_inv = reinterpret_cast(buffers[4]); - auto *output = buffers[5]; - auto *rsigma = buffers[6]; - auto *amax_out = buffers[7]; - auto *workspace = buffers[8]; - auto *barrier = buffers[9]; - assert(amax_out == amax); - - void *bias = nullptr; - void *mu = nullptr; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto batch_size = desc.batch_size; - auto hidden_size = desc.hidden_size; - auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; - auto in_dtype = desc.x_dtype; - auto w_dtype = desc.w_dtype; - auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; - auto eps = desc.eps; - auto zero_centered_gamma = desc.zero_centered_gamma; - auto sm_margin = desc.sm_margin; - auto out_dtype = DType::kFloat8E4M3; - - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + auto *input = buffers[0]; + auto *weight = buffers[1]; + auto *amax = reinterpret_cast(buffers[2]); + auto *scale = reinterpret_cast(buffers[3]); + auto *scale_inv = reinterpret_cast(buffers[4]); + auto *output = buffers[5]; + auto *rsigma = buffers[6]; + auto *amax_out = buffers[7]; + auto *workspace = buffers[8]; + auto *barrier = buffers[9]; + assert(amax_out == 
amax); + + void *bias = nullptr; + void *mu = nullptr; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto batch_size = desc.batch_size; + auto hidden_size = desc.hidden_size; + auto wkspace_size = desc.wkspace_size; + auto barrier_size = desc.barrier_size; + auto in_dtype = desc.x_dtype; + auto w_dtype = desc.w_dtype; + auto wkspace_dtype = desc.wkspace_dtype; + auto barrier_dtype = desc.barrier_dtype; + auto eps = desc.eps; + auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; + auto out_dtype = DType::kFloat8E4M3; + + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, + eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, + wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, + stream); } void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *weight = buffers[1]; - auto *output = buffers[2]; - auto *rsigma = buffers[3]; - auto *workspace = buffers[4]; - auto *barrier = buffers[5]; - - void *bias = nullptr; - void *mu = nullptr; - float *amax = nullptr; - float *scale = nullptr; - float *scale_inv = nullptr; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto batch_size = desc.batch_size; - auto hidden_size = desc.hidden_size; - auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; - auto in_dtype = desc.x_dtype; - auto w_dtype = desc.w_dtype; - auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; - auto eps = desc.eps; - auto zero_centered_gamma = desc.zero_centered_gamma; - auto sm_margin = desc.sm_margin; - auto out_dtype = in_dtype; - - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + auto *input = buffers[0]; + auto *weight = buffers[1]; + auto *output = buffers[2]; + auto *rsigma = buffers[3]; + auto *workspace = buffers[4]; + auto *barrier = buffers[5]; + + void *bias = nullptr; + void *mu = nullptr; + float *amax = nullptr; + float *scale = nullptr; + float *scale_inv = nullptr; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto batch_size = desc.batch_size; + auto hidden_size = desc.hidden_size; + auto wkspace_size = desc.wkspace_size; + auto barrier_size = desc.barrier_size; + auto in_dtype = desc.x_dtype; + auto w_dtype = desc.w_dtype; + auto wkspace_dtype = desc.wkspace_dtype; + auto barrier_dtype = desc.barrier_dtype; + auto eps = desc.eps; + auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; + auto out_dtype = in_dtype; + + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, + eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, + wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, + stream); } void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *ograd = buffers[0]; - auto *rsigma = buffers[1]; - auto *input = buffers[2]; - auto *weight = buffers[3]; - auto *xgrad = buffers[4]; - auto *wgrad = buffers[5]; - auto *workspace = buffers[6]; - auto *barrier = buffers[7]; - auto *dgamma_part = buffers[8]; - - void *mu = nullptr; - void *dbeta = nullptr; - void *dbeta_part = nullptr; - - const auto 
&desc = *UnpackOpaque(opaque, opaque_len); - auto batch_size = desc.batch_size; - auto hidden_size = desc.hidden_size; - auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; - auto dgamma_part_shape = desc.dgamma_part_shape; - Shape dbeta_part_shape; - dbeta_part_shape.from_vector({0, 0}); - auto in_dtype = desc.x_dtype; - auto w_dtype = desc.w_dtype; - auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; - auto dgamma_part_dtype = desc.dgamma_part_dtype; - auto dbeta_part_dtype = DType::kByte; - auto eps = desc.eps; - auto zero_centered_gamma = desc.zero_centered_gamma; - - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, - dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, - w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, - rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, stream); + auto *ograd = buffers[0]; + auto *rsigma = buffers[1]; + auto *input = buffers[2]; + auto *weight = buffers[3]; + auto *xgrad = buffers[4]; + auto *wgrad = buffers[5]; + auto *workspace = buffers[6]; + auto *barrier = buffers[7]; + auto *dgamma_part = buffers[8]; + + void *mu = nullptr; + void *dbeta = nullptr; + void *dbeta_part = nullptr; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto batch_size = desc.batch_size; + auto hidden_size = desc.hidden_size; + auto wkspace_size = desc.wkspace_size; + auto barrier_size = desc.barrier_size; + auto dgamma_part_shape = desc.dgamma_part_shape; + Shape dbeta_part_shape; + dbeta_part_shape.from_vector({0, 0}); + auto in_dtype = desc.x_dtype; + auto w_dtype = desc.w_dtype; + auto wkspace_dtype = desc.wkspace_dtype; + auto barrier_dtype = desc.barrier_dtype; + auto dgamma_part_dtype = desc.dgamma_part_dtype; + auto dbeta_part_dtype = DType::kByte; + auto eps = desc.eps; + auto zero_centered_gamma = desc.zero_centered_gamma; + + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, + dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, + w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, + rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, + dbeta_part_dtype, stream); } } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index d06c8d7dddfc94c6c8938e13df1e4a03fb5728e2..a79e02c3840894657e524cc19a00d3fa9f373832 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -6,32 +6,30 @@ #include "jax/csrc/extensions.h" - namespace transformer_engine { namespace jax { pybind11::bytes PackCustomCallCommonDescriptor(const std::vector &shape, DType in_dtype, DType out_dtype, size_t act_enum) { - CustomCallCommonDescriptor desc{}; - desc.shape.from_vector(shape); - desc.in_dtype = in_dtype; - desc.out_dtype = out_dtype; - desc.act_enum = act_enum; - return PackOpaque(desc); + CustomCallCommonDescriptor desc{}; + desc.shape.from_vector(shape); + desc.in_dtype = in_dtype; + desc.out_dtype = out_dtype; + desc.act_enum = act_enum; + return PackOpaque(desc); } pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector &shape, const std::vector &wkshape, DType in_dtype, - DType out_dtype, DType wk_dtype, - size_t act_enum) { - CustomCallCommonWkDescriptor desc{}; - desc.shape.from_vector(shape); - desc.wkshape.from_vector(wkshape); - 
desc.in_dtype = in_dtype; - desc.out_dtype = out_dtype; - desc.wk_dtype = wk_dtype; - desc.act_enum = act_enum; - return PackOpaque(desc); + DType out_dtype, DType wk_dtype, size_t act_enum) { + CustomCallCommonWkDescriptor desc{}; + desc.shape.from_vector(shape); + desc.wkshape.from_vector(wkshape); + desc.in_dtype = in_dtype; + desc.out_dtype = out_dtype; + desc.wk_dtype = wk_dtype; + desc.act_enum = act_enum; + return PackOpaque(desc); } pybind11::bytes PackCustomCallNormDescriptor( @@ -39,30 +37,30 @@ pybind11::bytes PackCustomCallNormDescriptor( const std::vector &dgamma_part_shape, const std::vector &dbeta_part_shape, DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype, DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) { - CustomCallNormDescriptor desc{}; - desc.batch_size = batch_size; - desc.hidden_size = hidden_size; - desc.wkspace_size = wkspace_size; - desc.barrier_size = barrier_size; - desc.dgamma_part_shape.from_vector(dgamma_part_shape); - desc.dbeta_part_shape.from_vector(dbeta_part_shape); - desc.x_dtype = x_dtype; - desc.w_dtype = w_dtype; - desc.wkspace_dtype = wkspace_dtype; - desc.barrier_dtype = barrier_dtype; - desc.dgamma_part_dtype = dgamma_part_dtype; - desc.dbeta_part_dtype = dbeta_part_dtype; - desc.zero_centered_gamma = zero_centered_gamma; - desc.eps = eps; - desc.sm_margin = sm_margin; - return PackOpaque(desc); + CustomCallNormDescriptor desc{}; + desc.batch_size = batch_size; + desc.hidden_size = hidden_size; + desc.wkspace_size = wkspace_size; + desc.barrier_size = barrier_size; + desc.dgamma_part_shape.from_vector(dgamma_part_shape); + desc.dbeta_part_shape.from_vector(dbeta_part_shape); + desc.x_dtype = x_dtype; + desc.w_dtype = w_dtype; + desc.wkspace_dtype = wkspace_dtype; + desc.barrier_dtype = barrier_dtype; + desc.dgamma_part_dtype = dgamma_part_dtype; + desc.dbeta_part_dtype = dbeta_part_dtype; + desc.zero_centered_gamma = zero_centered_gamma; + desc.eps = eps; + desc.sm_margin = sm_margin; + return PackOpaque(desc); } pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size, size_t head_dim, size_t q_seqlen, size_t k_seqlen, DType dtype, float scale_factor) { - return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen, - dtype, scale_factor}); + return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen, dtype, + scale_factor}); } pybind11::bytes PackCustomCallFusedAttnDescriptor( @@ -71,11 +69,11 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training) { - return PackOpaque(CustomCallFusedAttnDescriptor{ - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, - bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type, - mask_type, qkv_layout, dtype, wkspace_dtype, is_training}); + return PackOpaque(CustomCallFusedAttnDescriptor{ + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, + head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, + dtype, wkspace_dtype, is_training}); } -} // namespace jax -} // namespace transformer_engine +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp 
b/transformer_engine/jax/csrc/extensions/pybind.cpp index 14dff0776397b7f8e91e939ab892a5d2149f71b4..f621572bd4242164abe715f80b4d11a994686d31 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -11,106 +11,106 @@ namespace jax { template pybind11::capsule EncapsulateFunction(T *fn) { - return pybind11::capsule(reinterpret_cast(fn), "xla._CUSTOM_CALL_TARGET"); + return pybind11::capsule(reinterpret_cast(fn), "xla._CUSTOM_CALL_TARGET"); } pybind11::dict Registrations() { - pybind11::dict dict; - dict["te_transpose"] = EncapsulateFunction(Transpose); - dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose); + pybind11::dict dict; + dict["te_transpose"] = EncapsulateFunction(Transpose); + dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose); - dict["te_act_lu"] = EncapsulateFunction(ActLu); - dict["te_act_lu_fp8"] = EncapsulateFunction(ActLuFP8); - dict["te_dact_lu"] = EncapsulateFunction(DActLu); - dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose); - dict["te_dact_lu_dbias_cast_transpose"] = EncapsulateFunction(DActLuDBiasCastTranspose); - dict["te_dgated_act_lu_cast_transpose"] = EncapsulateFunction(DGatedActLuCastTranspose); + dict["te_act_lu"] = EncapsulateFunction(ActLu); + dict["te_act_lu_fp8"] = EncapsulateFunction(ActLuFP8); + dict["te_dact_lu"] = EncapsulateFunction(DActLu); + dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose); + dict["te_dact_lu_dbias_cast_transpose"] = EncapsulateFunction(DActLuDBiasCastTranspose); + dict["te_dgated_act_lu_cast_transpose"] = EncapsulateFunction(DGatedActLuCastTranspose); - dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward); - dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8); - dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward); - dict["te_rmsnorm_forward"] = EncapsulateFunction(RMSNormForward); - dict["te_rmsnorm_forward_fp8"] = EncapsulateFunction(RMSNormForwardFP8); - dict["te_rmsnorm_backward"] = EncapsulateFunction(RMSNormBackward); - dict["te_quantize"] = EncapsulateFunction(Quantize); - dict["te_dequantize"] = EncapsulateFunction(Dequantize); - dict["te_scaled_softmax_forward"] = EncapsulateFunction(ScaledSoftmaxForward); - dict["te_scaled_softmax_backward"] = EncapsulateFunction(ScaledSoftmaxBackward); - dict["te_scaled_masked_softmax_forward"] = EncapsulateFunction(ScaledMaskedSoftmaxForward); - dict["te_scaled_masked_softmax_backward"] = EncapsulateFunction(ScaledMaskedSoftmaxBackward); - dict["te_scaled_upper_triang_masked_softmax_forward"] = - EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward); - dict["te_scaled_upper_triang_masked_softmax_backward"] = - EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); - dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); - dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); - return dict; + dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward); + dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8); + dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward); + dict["te_rmsnorm_forward"] = EncapsulateFunction(RMSNormForward); + dict["te_rmsnorm_forward_fp8"] = EncapsulateFunction(RMSNormForwardFP8); + dict["te_rmsnorm_backward"] = EncapsulateFunction(RMSNormBackward); + dict["te_quantize"] = EncapsulateFunction(Quantize); + dict["te_dequantize"] = EncapsulateFunction(Dequantize); + 
dict["te_scaled_softmax_forward"] = EncapsulateFunction(ScaledSoftmaxForward); + dict["te_scaled_softmax_backward"] = EncapsulateFunction(ScaledSoftmaxBackward); + dict["te_scaled_masked_softmax_forward"] = EncapsulateFunction(ScaledMaskedSoftmaxForward); + dict["te_scaled_masked_softmax_backward"] = EncapsulateFunction(ScaledMaskedSoftmaxBackward); + dict["te_scaled_upper_triang_masked_softmax_forward"] = + EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward); + dict["te_scaled_upper_triang_masked_softmax_backward"] = + EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); + dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); + dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); + return dict; } PYBIND11_MODULE(transformer_engine_jax, m) { - m.def("registrations", &Registrations); - m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor, - pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0); - m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor, - pybind11::arg(), pybind11::arg(), pybind11::arg(), - pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0); - m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor); - m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); - m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); - m.def("get_fused_attn_backend", &GetFusedAttnBackend); - m.def("get_cuda_version", &GetCudaRuntimeVersion); - m.def("get_device_compute_capability", &GetDeviceComputeCapability); - m.def("get_cublasLt_version", &cublasLtGetVersion); - m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes); - m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes); - m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); - m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); - m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); - m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); + m.def("registrations", &Registrations); + m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor, pybind11::arg(), pybind11::arg(), + pybind11::arg(), pybind11::arg("act_num") = 0); + m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor, pybind11::arg(), + pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg(), + pybind11::arg("act_num") = 0); + m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor); + m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); + m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); + m.def("get_fused_attn_backend", &GetFusedAttnBackend); + m.def("get_cuda_version", &GetCudaRuntimeVersion); + m.def("get_device_compute_capability", &GetDeviceComputeCapability); + m.def("get_cublasLt_version", &cublasLtGetVersion); + m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes); + m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes); + m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); + m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); + m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); + m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); - pybind11::enum_(m, "DType", pybind11::module_local()) - .value("kByte", DType::kByte) - 
.value("kInt32", DType::kInt32) - .value("kInt64", DType::kInt64) - .value("kFloat32", DType::kFloat32) - .value("kFloat16", DType::kFloat16) - .value("kBFloat16", DType::kBFloat16) - .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); + pybind11::enum_<DType>(m, "DType", pybind11::module_local()) + .value("kByte", DType::kByte) + .value("kInt32", DType::kInt32) + .value("kInt64", DType::kInt64) + .value("kFloat32", DType::kFloat32) + .value("kFloat16", DType::kFloat16) + .value("kBFloat16", DType::kBFloat16) + .value("kFloat8E4M3", DType::kFloat8E4M3) + .value("kFloat8E5M2", DType::kFloat8E5M2); - pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", pybind11::module_local()) - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK); + pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", pybind11::module_local()) + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK); - pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local()) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); + pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local()) + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); - pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local()) - .value("GELU", NVTE_Activation_Type::GELU) - .value("GEGLU", NVTE_Activation_Type::GEGLU) - .value("SILU", NVTE_Activation_Type::SILU) - .value("SWIGLU", NVTE_Activation_Type::SWIGLU) - .value("RELU", NVTE_Activation_Type::RELU) - .value("REGLU", NVTE_Activation_Type::REGLU) - .value("QGELU", NVTE_Activation_Type::QGELU) - .value("QGEGLU", NVTE_Activation_Type::QGEGLU) - .value("SRELU", NVTE_Activation_Type::SRELU) - .value("SREGLU", NVTE_Activation_Type::SREGLU); + pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local()) + .value("GELU", NVTE_Activation_Type::GELU) + .value("GEGLU", NVTE_Activation_Type::GEGLU) + .value("SILU", NVTE_Activation_Type::SILU) + .value("SWIGLU", NVTE_Activation_Type::SWIGLU) + .value("RELU", NVTE_Activation_Type::RELU) + .value("REGLU", NVTE_Activation_Type::REGLU) + .value("QGELU", NVTE_Activation_Type::QGELU) + .value("QGEGLU", NVTE_Activation_Type::QGEGLU) + .value("SRELU", NVTE_Activation_Type::SRELU) + .value("SREGLU", NVTE_Activation_Type::SREGLU); - pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) -
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8); + pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8); } } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 3da30ad116da887977bb1edce61fb9c30ba92d72..0056a630cdb216f134b6985617f2540c0f40bfbe 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -11,37 +11,37 @@ namespace transformer_engine { namespace jax { void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *amax = reinterpret_cast(buffers[1]); - auto *scale = reinterpret_cast(buffers[2]); - auto *scale_inv = reinterpret_cast(buffers[3]); - auto *output = buffers[4]; - auto *amax_out = reinterpret_cast(buffers[5]); - assert(amax == amax_out); - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto shape = desc.shape.to_vector(); - auto input_tensor = TensorWrapper(input, shape, desc.in_dtype); - auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv); - - nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream); + auto *input = buffers[0]; + auto *amax = reinterpret_cast(buffers[1]); + auto *scale = reinterpret_cast(buffers[2]); + auto *scale_inv = reinterpret_cast(buffers[3]); + auto *output = buffers[4]; + auto *amax_out = reinterpret_cast(buffers[5]); + assert(amax == amax_out); + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto shape = desc.shape.to_vector(); + auto input_tensor = TensorWrapper(input, shape, desc.in_dtype); + auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv); + + nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream); } void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *amax = reinterpret_cast(buffers[1]); - auto *scale = reinterpret_cast(buffers[2]); - auto *scale_inv = reinterpret_cast(buffers[3]); - auto *output = buffers[4]; + auto *input = buffers[0]; + auto *amax = reinterpret_cast(buffers[1]); + auto *scale = reinterpret_cast(buffers[2]); + auto *scale_inv = reinterpret_cast(buffers[3]); + auto *output = buffers[4]; - const auto &desc = *UnpackOpaque(opaque, opaque_len); + const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto shape = desc.shape.to_vector(); - auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv); + auto shape = desc.shape.to_vector(); + auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv); - auto output_tensor = TensorWrapper(output, shape, desc.out_dtype); + auto output_tensor = TensorWrapper(output, shape, desc.out_dtype); - nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); + nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); } } // namespace jax diff --git 
a/transformer_engine/jax/csrc/extensions/softmax.cpp b/transformer_engine/jax/csrc/extensions/softmax.cpp index e6c01b2acca45062bf27f873c9ef84c59306bda2..18d59667a9d29b4856019b6102028ac5ba04d203 100644 --- a/transformer_engine/jax/csrc/extensions/softmax.cpp +++ b/transformer_engine/jax/csrc/extensions/softmax.cpp @@ -4,112 +4,109 @@ * See LICENSE for license information. ************************************************************************/ -#include "jax/csrc/extensions.h" #include "transformer_engine/softmax.h" +#include "jax/csrc/extensions.h" namespace transformer_engine { namespace jax { void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *output = buffers[1]; + auto *input = buffers[0]; + auto *output = buffers[1]; - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto shape = std::vector{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen}; - auto dtype = desc.dtype; + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto shape = std::vector{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen}; + auto dtype = desc.dtype; - auto input_tensor = TensorWrapper(input, shape, dtype); - auto output_tensor = TensorWrapper(output, shape, dtype); + auto input_tensor = TensorWrapper(input, shape, dtype); + auto output_tensor = TensorWrapper(output, shape, dtype); - nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), desc.scale_factor, - stream); + nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), desc.scale_factor, stream); } void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *grad_output = buffers[0]; - auto *softmax_output = buffers[1]; - auto *dgrad = buffers[2]; + auto *grad_output = buffers[0]; + auto *softmax_output = buffers[1]; + auto *dgrad = buffers[2]; - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto shape = std::vector{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen}; - auto dtype = desc.dtype; + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto shape = std::vector{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen}; + auto dtype = desc.dtype; - auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype); - auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype); - auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype); + auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype); + auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype); + auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype); - nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(), - dgrad_tensor.data(), desc.scale_factor, stream); + nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(), + dgrad_tensor.data(), desc.scale_factor, stream); } void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *mask = buffers[1]; - auto *output = buffers[2]; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto io_shape = - std::vector{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen}; - auto mask_shape = std::vector{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen}; - auto dtype = desc.dtype; - - auto input_tensor = TensorWrapper(input, io_shape, dtype); - // Mask would be casted to uint8_t - auto mask_tensor = TensorWrapper(mask, 
mask_shape, DType::kByte); - auto output_tensor = TensorWrapper(output, io_shape, dtype); - - nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(), - output_tensor.data(), desc.scale_factor, stream); + auto *input = buffers[0]; + auto *mask = buffers[1]; + auto *output = buffers[2]; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto io_shape = std::vector{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen}; + auto mask_shape = std::vector{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen}; + auto dtype = desc.dtype; + + auto input_tensor = TensorWrapper(input, io_shape, dtype); + // Mask would be casted to uint8_t + auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte); + auto output_tensor = TensorWrapper(output, io_shape, dtype); + + nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(), output_tensor.data(), + desc.scale_factor, stream); } void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - // The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax. - ScaledSoftmaxBackward(stream, buffers, opaque, opaque_len); + // The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax. + ScaledSoftmaxBackward(stream, buffers, opaque, opaque_len); } void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *output = buffers[1]; + auto *input = buffers[0]; + auto *output = buffers[1]; - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto attn_batch = desc.batch_size * desc.head_dim; - auto shape = std::vector{attn_batch, desc.q_seqlen, desc.k_seqlen}; - auto dtype = desc.dtype; + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto attn_batch = desc.batch_size * desc.head_dim; + auto shape = std::vector{attn_batch, desc.q_seqlen, desc.k_seqlen}; + auto dtype = desc.dtype; - auto input_tensor = TensorWrapper(input, shape, dtype); + auto input_tensor = TensorWrapper(input, shape, dtype); - auto output_tensor = TensorWrapper(output, shape, dtype); + auto output_tensor = TensorWrapper(output, shape, dtype); - nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(), - desc.scale_factor, stream); + nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(), + desc.scale_factor, stream); } void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *grad_output = buffers[0]; - auto *softmax_output = buffers[1]; - auto *dgrad = buffers[2]; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto attn_batch = desc.batch_size * desc.head_dim; - auto shape = std::vector{attn_batch, desc.q_seqlen, desc.k_seqlen}; - auto dtype = desc.dtype; - - auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype); - auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype); - auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype); - - nvte_scaled_upper_triang_masked_softmax_backward( - grad_output_tensor.data(), softmax_output_tensor.data(), dgrad_tensor.data(), - desc.scale_factor, stream); + auto *grad_output = buffers[0]; + auto *softmax_output = buffers[1]; + auto *dgrad = buffers[2]; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto attn_batch = desc.batch_size * desc.head_dim; + auto shape = std::vector{attn_batch, desc.q_seqlen, desc.k_seqlen}; + auto dtype = desc.dtype; 
+ + auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype); + auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype); + auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype); + + nvte_scaled_upper_triang_masked_softmax_backward(grad_output_tensor.data(), + softmax_output_tensor.data(), + dgrad_tensor.data(), desc.scale_factor, stream); } } // namespace jax } // namespace transformer_engine - diff --git a/transformer_engine/jax/csrc/extensions/transpose.cpp b/transformer_engine/jax/csrc/extensions/transpose.cpp index e8bec2f7eb890eb15dca75b64c94cf399d8c402f..88861a80d4225a202da72c408a92ffeb11aa02c2 100644 --- a/transformer_engine/jax/csrc/extensions/transpose.cpp +++ b/transformer_engine/jax/csrc/extensions/transpose.cpp @@ -4,127 +4,126 @@ * See LICENSE for license information. ************************************************************************/ -#include "jax/csrc/extensions.h" #include "transformer_engine/transpose.h" +#include "jax/csrc/extensions.h" + namespace transformer_engine { namespace jax { void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream, void *output) { - auto input_shape = std::vector{rows, cols}; - auto output_shape = std::vector{cols, rows}; + auto input_shape = std::vector{rows, cols}; + auto output_shape = std::vector{cols, rows}; - auto input_tensor = TensorWrapper(input, input_shape, dtype); - auto transposed_tensor = TensorWrapper(output, output_shape, dtype); + auto input_tensor = TensorWrapper(input, input_shape, dtype); + auto transposed_tensor = TensorWrapper(output, output_shape, dtype); - nvte_transpose(input_tensor.data(), transposed_tensor.data(), stream); + nvte_transpose(input_tensor.data(), transposed_tensor.data(), stream); } void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - void *input = buffers[0]; - void *output = buffers[1]; + void *input = buffers[0]; + void *output = buffers[1]; - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto rows = desc.shape.dims[0]; - auto cols = desc.shape.dims[1]; - assert(desc.in_dtype == desc.out_dtype); - auto dtype = desc.out_dtype; + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto rows = desc.shape.dims[0]; + auto cols = desc.shape.dims[1]; + assert(desc.in_dtype == desc.out_dtype); + auto dtype = desc.out_dtype; - TransposeImpl(input, rows, cols, dtype, stream, output); + TransposeImpl(input, rows, cols, dtype, stream, output); } void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - float *amax = reinterpret_cast(buffers[1]); - float *scale = reinterpret_cast(buffers[2]); - float *scale_inv = reinterpret_cast(buffers[3]); - auto *input_cast = buffers[4]; - auto *input_cast_trans = buffers[5]; - float *amax_out = reinterpret_cast(buffers[6]); - assert(amax == amax_out); - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - if (!use_fp8(desc.out_dtype)) { - scale = nullptr; - scale_inv = nullptr; - amax_out = nullptr; - } - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - auto input_shape = std::vector{m, n}; - auto input_trans_shape = std::vector{n, m}; - - auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto input_cast_tensor = - TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape, - desc.out_dtype, amax_out, scale, scale_inv); - - 
nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), - input_cast_trans_tensor.data(), stream); + auto *input = buffers[0]; + float *amax = reinterpret_cast<float *>(buffers[1]); + float *scale = reinterpret_cast<float *>(buffers[2]); + float *scale_inv = reinterpret_cast<float *>(buffers[3]); + auto *input_cast = buffers[4]; + auto *input_cast_trans = buffers[5]; + float *amax_out = reinterpret_cast<float *>(buffers[6]); + assert(amax == amax_out); + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + if (!use_fp8(desc.out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto input_shape = std::vector<size_t>{m, n}; + auto input_trans_shape = std::vector<size_t>{n, m}; + + auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); + auto input_cast_tensor = + TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape, desc.out_dtype, + amax_out, scale, scale_inv); + + nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(), + stream); } pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype) { - auto input_shape = std::vector<size_t>{batch_size, hidden_size}; - auto output_shape = std::vector<size_t>{batch_size, hidden_size}; - auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size}; - auto dbias_shape = std::vector<size_t>{hidden_size}; + DType in_dtype, DType out_dtype) { + auto input_shape = std::vector<size_t>{batch_size, hidden_size}; + auto output_shape = std::vector<size_t>{batch_size, hidden_size}; + auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size}; + auto dbias_shape = std::vector<size_t>{hidden_size}; - auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); - auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); - auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); + auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); + auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); + auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); - TensorWrapper dummy_workspace; + TensorWrapper dummy_workspace; - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), dbias_tensor.data(), - dummy_workspace.data(), nullptr); + nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), dummy_workspace.data(), nullptr); - auto work_shape = MakeShapeVector(dummy_workspace.shape()); - return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); + auto work_shape = MakeShapeVector(dummy_workspace.shape()); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); } void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - auto *input = buffers[0]; - float *amax = reinterpret_cast<float *>(buffers[1]); - float *scale = reinterpret_cast<float *>(buffers[2]); - float *scale_inv = reinterpret_cast<float *>(buffers[3]); - auto *output = buffers[4]; - auto *output_trans = buffers[5]; - auto *dbias = buffers[6]; - float *amax_out = reinterpret_cast<float *>(buffers[7]); - void *workspace_ptr = buffers[8]; - - const auto &desc = 
*UnpackOpaque(opaque, opaque_len); - assert(amax == amax_out); - if (!use_fp8(desc.out_dtype)) { - scale = nullptr; - scale_inv = nullptr; - amax_out = nullptr; - } - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - auto input_shape = std::vector{m, n}; - auto output_shape = std::vector{m, n}; - auto output_trans_shape = std::vector{n, m}; - auto dbias_shape = std::vector{n}; - - auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto output_tensor = - TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); - - auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); - - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), dbias_tensor.data(), - workspace.data(), stream); + size_t opaque_len) { + auto *input = buffers[0]; + float *amax = reinterpret_cast(buffers[1]); + float *scale = reinterpret_cast(buffers[2]); + float *scale_inv = reinterpret_cast(buffers[3]); + auto *output = buffers[4]; + auto *output_trans = buffers[5]; + auto *dbias = buffers[6]; + float *amax_out = reinterpret_cast(buffers[7]); + void *workspace_ptr = buffers[8]; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + assert(amax == amax_out); + if (!use_fp8(desc.out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto input_shape = std::vector{m, n}; + auto output_shape = std::vector{m, n}; + auto output_trans_shape = std::vector{n, m}; + auto dbias_shape = std::vector{n}; + + auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); + auto output_tensor = + TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto output_trans_tensor = + TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); + + auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); + + nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); } } // namespace jax diff --git a/transformer_engine/jax/csrc/utils.cu b/transformer_engine/jax/csrc/utils.cu index 5d3d9621c45e8874e19f482450109bfb2a600170..a8a3d7557e7ec07e0918b50af2dfdc582d86d618 100644 --- a/transformer_engine/jax/csrc/utils.cu +++ b/transformer_engine/jax/csrc/utils.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ #include + #include #include "common/util/cuda_runtime.h" @@ -13,35 +14,35 @@ namespace transformer_engine { namespace jax { int GetCudaRuntimeVersion() { - int ver = 0; - NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver)); - return ver; + int ver = 0; + NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver)); + return ver; } int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); } __global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed, int64_t offset) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid > 0) return; - rng_state_dst[0] = seed[0]; - rng_state_dst[1] = offset; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid > 0) return; + rng_state_dst[0] = seed[0]; + rng_state_dst[1] = offset; } void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen, size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend, cudaStream_t stream) { - size_t increment = 0; - if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { - increment = 16; - } else { - constexpr int threads_per_cta = 128; - increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta; - } - auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment); - populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(rng_state_dst), - reinterpret_cast(seed), offset); - NVTE_CHECK_CUDA(cudaGetLastError()); + size_t increment = 0; + if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + increment = 16; + } else { + constexpr int threads_per_cta = 128; + increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta; + } + auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment); + populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(rng_state_dst), + reinterpret_cast(seed), offset); + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace jax diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index 01e10de9849021e6a5216b182242ce9caaed3b76..640b6daba1ce4b6b1d8d4f1899a4f6ba13b2e2d2 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -7,16 +7,16 @@ #ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_ #define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_ +#include +#include + #include #include #include #include #include -#include - #include "common/util/logging.h" -#include namespace transformer_engine { namespace jax { @@ -30,55 +30,55 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q class cudaDevicePropertiesManager { public: - static cudaDevicePropertiesManager &Instance() { - static thread_local cudaDevicePropertiesManager instance; - return instance; + static cudaDevicePropertiesManager &Instance() { + static thread_local cudaDevicePropertiesManager instance; + return instance; + } + + int GetMultiProcessorCount() { + if (!prop_queried_) { + int device_id; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + cudaGetDeviceProperties(&prop_, device_id); + prop_queried_ = true; } - - int GetMultiProcessorCount() { - if (!prop_queried_) { - int device_id; - NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); - cudaGetDeviceProperties(&prop_, device_id); - prop_queried_ = true; - } - return prop_.multiProcessorCount; - } - - int GetMajor() { - if (!prop_queried_) { - int device_id; - NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); - 
cudaGetDeviceProperties(&prop_, device_id); - prop_queried_ = true; - } - return prop_.major; + return prop_.multiProcessorCount; + } + + int GetMajor() { + if (!prop_queried_) { + int device_id; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + cudaGetDeviceProperties(&prop_, device_id); + prop_queried_ = true; } + return prop_.major; + } private: - bool prop_queried_ = false; - cudaDeviceProp prop_; + bool prop_queried_ = false; + cudaDeviceProp prop_; }; class FusedAttnOffsetManager { public: - static FusedAttnOffsetManager &Instance() { - static thread_local FusedAttnOffsetManager instance; - return instance; - } + static FusedAttnOffsetManager &Instance() { + static thread_local FusedAttnOffsetManager instance; + return instance; + } - size_t GetAndUpdateOffset(size_t increment) { - size_t ret = offset_; - offset_ += increment; - return ret; - } + size_t GetAndUpdateOffset(size_t increment) { + size_t ret = offset_; + offset_ += increment; + return ret; + } - FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete; - void operator=(FusedAttnOffsetManager const &) = delete; + FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete; + void operator=(FusedAttnOffsetManager const &) = delete; private: - FusedAttnOffsetManager() {} - size_t offset_ = 0; + FusedAttnOffsetManager() {} + size_t offset_ = 0; }; } // namespace jax diff --git a/transformer_engine/jax/dot.py b/transformer_engine/jax/dot.py index cd45a9634083f4f2ff6e1211881d1ab004dbe740..8981af8b7c9c5a2c71e6fd42a8bb439c4a05c70f 100644 --- a/transformer_engine/jax/dot.py +++ b/transformer_engine/jax/dot.py @@ -18,7 +18,7 @@ def type_safe_dot_general( x, kernel, fp8_meta_pkg: FP8MetaPackage = None, - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)) + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), ) -> jnp.ndarray: """ Type safe dot_general, including FP8. @@ -56,13 +56,14 @@ def dequantize(x, dq_dtype, scale_inv): # Apply jit to guarantee correctness of FP8 GEMM. 
@partial(jax.jit, static_argnums=(4, 5, 6)) def fp8_dot_impl( - q_lhs: jnp.ndarray, - q_rhs: jnp.ndarray, - lhs_scale_inv: jnp.ndarray, - rhs_scale_inv: jnp.ndarray, - ctype: jnp.dtype, # computing type - contracting_dims: Tuple[Sequence[int], Sequence[int]], - precision: Precision = None): + q_lhs: jnp.ndarray, + q_rhs: jnp.ndarray, + lhs_scale_inv: jnp.ndarray, + rhs_scale_inv: jnp.ndarray, + ctype: jnp.dtype, # computing type + contracting_dims: Tuple[Sequence[int], Sequence[int]], + precision: Precision = None, +): """ FP8 GEMM for XLA pattern match """ @@ -82,37 +83,47 @@ def get_precision_of_fp8_dot(enable_2xACC: bool): @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6)) -def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, amax_list: List[jnp.ndarray], - scale_list: List[jnp.ndarray], fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, - contracting_dims: Tuple[Sequence[int], Sequence[int]]): - output, _ = _fp8_dot_fwd_rule(x, kernel, amax_list, scale_list, fwd_dtype, bwd_dtype, - contracting_dims) +def _fp8_dot( + x: jnp.ndarray, + kernel: jnp.ndarray, + amax_list: List[jnp.ndarray], + scale_list: List[jnp.ndarray], + fwd_dtype: jnp.dtype, + bwd_dtype: jnp.dtype, + contracting_dims: Tuple[Sequence[int], Sequence[int]], +): + output, _ = _fp8_dot_fwd_rule( + x, kernel, amax_list, scale_list, fwd_dtype, bwd_dtype, contracting_dims + ) return output def _fp8_dot_fwd_rule( - x, - kernel, - amax_list, - scale_list, - fwd_dtype, - bwd_dtype, # pylint: disable=unused-argument - contracting_dims): - - maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \ - FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list, *scale_list) + x, + kernel, + amax_list, + scale_list, + fwd_dtype, + bwd_dtype, # pylint: disable=unused-argument + contracting_dims, +): + + maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( + *amax_list, *scale_list + ) amax_list = maybe_fm32_to_fp32(*amax_list) scale_list = maybe_fm32_to_fp32(*scale_list) lhs_contracting_dims, rhs_contracting_dims = contracting_dims - x_shape_suf = x.shape[min(lhs_contracting_dims):] - kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1] + x_shape_suf = x.shape[min(lhs_contracting_dims) :] + kernel_shape_pre = kernel.shape[: max(rhs_contracting_dims) + 1] assert x_shape_suf == kernel_shape_pre fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype] - scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(amax_list, scale_list, - fp8_dtype_list) + scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale( + amax_list, scale_list, fp8_dtype_list + ) amax_list = FP8MetaPackage.update_amax_list(amax_list) x_scale = scale_list[FP8MetaPackage.INPUT_IDX] @@ -127,52 +138,100 @@ def _fp8_dot_fwd_rule( # unnecessary copy to break FP8 GEMM pattern matching. 
casted_kernel, updated_kernel_amax = quantize(kernel, fwd_dtype, kernel_scale) - output = fp8_dot_impl(casted_x, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype, - (lhs_contracting_dims, rhs_contracting_dims), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) - - ctx = (casted_x, casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax, - updated_kernel_amax, x.shape, kernel.shape, maybe_fp32_to_fm32) + output = fp8_dot_impl( + casted_x, + casted_kernel, + x_scale_inv, + kernel_scale_inv, + x.dtype, + (lhs_contracting_dims, rhs_contracting_dims), + get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP), + ) + + ctx = ( + casted_x, + casted_kernel, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + x.shape, + kernel.shape, + maybe_fp32_to_fm32, + ) return output, ctx -def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # pylint: disable=unused-argument +def _fp8_dot_bwd_rule( + fwd_dtype, bwd_dtype, contracting_dims, ctx, grad +): # pylint: disable=unused-argument lhs_contracting_dims, rhs_contracting_dims = contracting_dims - casted_x, casted_kernel, amax_list, scale_list, scale_inv_list, \ - updated_x_amax, updated_kernel_amax, x_shape, kernel_shape, \ - maybe_fp32_to_fm32 = ctx + ( + casted_x, + casted_kernel, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + x_shape, + kernel_shape, + maybe_fp32_to_fm32, + ) = ctx grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1] grad_scale = scale_list[FP8MetaPackage.GRAD_IDX] grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX] - casted_grad, casted_grad_t, updated_grad_amax = \ - tex.cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, - bwd_dtype, static_axis_boundary=-1, - transpose_axis_boundary=min(lhs_contracting_dims)) + casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=min(lhs_contracting_dims), + ) x_constracting_dim = tuple(range(0, len(x_shape) - len(lhs_contracting_dims))) gt_constracting_dim = tuple(range(grad.ndim - len(x_constracting_dim), grad.ndim)) x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] - wgrad = fp8_dot_impl(casted_x, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype, - (x_constracting_dim, gt_constracting_dim), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) + wgrad = fp8_dot_impl( + casted_x, + casted_grad_t, + x_scale_inv, + grad_scale_inv, + grad.dtype, + (x_constracting_dim, gt_constracting_dim), + get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD), + ) g_constracting_dim = tuple( - range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim)) + range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim) + ) k_constracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape))) kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] - dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype, - (g_constracting_dim, k_constracting_dim), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) - - amax_list[FP8MetaPackage.INPUT_IDX] = \ + dgrad = fp8_dot_impl( + casted_grad, + casted_kernel, + grad_scale_inv, + kernel_scale_inv, + grad.dtype, + (g_constracting_dim, k_constracting_dim), + get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD), + ) + + amax_list[FP8MetaPackage.INPUT_IDX] = ( amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax) - 
amax_list[FP8MetaPackage.WEIGHT_IDX] = \ + ) + amax_list[FP8MetaPackage.WEIGHT_IDX] = ( amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax) - amax_list[FP8MetaPackage.GRAD_IDX] = \ + ) + amax_list[FP8MetaPackage.GRAD_IDX] = ( amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0]) + ) amax_list = maybe_fp32_to_fm32(*amax_list) scale_list = maybe_fp32_to_fm32(*scale_list) diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index 914bce00b3319b7a52026d85d2c9b0efcc73539c..6655091caa059c897e6bafe474d8ea4724d0e927 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -9,15 +9,15 @@ from .transformer import DotProductAttention, MultiHeadAttention, RelativePositi from .transformer import TransformerLayer, TransformerLayerType __all__ = [ - 'DenseGeneral', - 'LayerNorm', - 'LayerNormDenseGeneral', - 'LayerNormMLP', - 'TransformerEngineBase', - 'extend_logical_axis_rules', - 'DotProductAttention', - 'MultiHeadAttention', - 'RelativePositionBiases', - 'TransformerLayer', - 'TransformerLayerType', + "DenseGeneral", + "LayerNorm", + "LayerNormDenseGeneral", + "LayerNormMLP", + "TransformerEngineBase", + "extend_logical_axis_rules", + "DotProductAttention", + "MultiHeadAttention", + "RelativePositionBiases", + "TransformerLayer", + "TransformerLayerType", ] diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index a5ffd01e8ae00f778b7cf12e94a4ebd098c8e135..e7388c20e0f97caca1bded09e5a870483e58052b 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -30,8 +30,9 @@ PRNGKey = Any Shape = Tuple[int, ...] DType = jnp.dtype Array = jnp.ndarray -PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, - lax.Precision]] +PrecisionLike = Union[ + None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] +] Initializer = Callable[[PRNGKey, Shape, DType], Array] @@ -55,25 +56,22 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga return nn.initializers.zeros -def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes, bias_init, - bias_axes, dtype): - scale = nn_partitioning.param_with_axes('scale', - scale_init, - shape, - jnp.float32, - axes=scale_axes) +def _create_layernorm_parameters( + layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype +): + scale = nn_partitioning.param_with_axes( + "scale", scale_init, shape, jnp.float32, axes=scale_axes + ) scale = jnp.asarray(scale, dtype) layernorm_type = canonicalize_layernorm_type(layernorm_type) - if layernorm_type == 'layernorm': - bias = nn_partitioning.param_with_axes('ln_bias', - bias_init, - shape, - jnp.float32, - axes=bias_axes) + if layernorm_type == "layernorm": + bias = nn_partitioning.param_with_axes( + "ln_bias", bias_init, shape, jnp.float32, axes=bias_axes + ) bias = jnp.asarray(bias, dtype) else: - assert layernorm_type == 'rmsnorm' + assert layernorm_type == "rmsnorm" bias = None return scale, bias @@ -81,7 +79,7 @@ def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes, def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable: """Convert a string to an activation function.""" - if fn_or_string == 'linear': + if fn_or_string == "linear": return lambda x: x if isinstance(fn_or_string, str): return getattr(nn, fn_or_string) @@ -96,8 +94,9 @@ def _combine_biases(*masks: 
List[Array]): masks = [m for m in masks if m is not None] if not masks: return None - assert all(map(lambda x: x.ndim == masks[0].ndim, - masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') + assert all( + map(lambda x: x.ndim == masks[0].ndim, masks) + ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}" mask, *other_masks = masks for other_mask in other_masks: mask = mask + other_mask @@ -108,10 +107,10 @@ def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, """Low Rank Adaptation Implementation""" assert len(axis) <= 5 - hidden_in_names = 'ijklm'[:len(axis)] + hidden_in_names = "ijklm"[: len(axis)] assert len(features) <= 5 - hidden_out_names = 'nopqr'[:len(features)] - rank_name = 's' + hidden_out_names = "nopqr"[: len(features)] + rank_name = "s" assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2] rank = lora_a_kernel.shape[-1] @@ -121,15 +120,17 @@ def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}" lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}" output_einsum_express = f"...{hidden_out_names}" - final_einsum_express = f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}" \ - f"->{output_einsum_express}" + final_einsum_express = ( + f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}" + f"->{output_einsum_express}" + ) output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel) output = output * scaling return output -class Softmax(nn.Module): # pylint: disable=too-few-public-methods +class Softmax(nn.Module): # pylint: disable=too-few-public-methods r""" Applies softmax over a mini-batch of inputs. The input's shape should be [batch, heads, q_seqlen, k_seqlen]. @@ -160,8 +161,9 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods dtype = inputs.dtype logits = inputs - if (self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available( - self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype)): + if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available( + self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype + ): if bias is not None: logits = logits + bias.astype(dtype) @@ -174,9 +176,11 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods else: attention_bias = None if mask is not None: - attention_bias = lax.select(mask > 0, - jnp.full(mask.shape, -1e10).astype(dtype), - jnp.full(mask.shape, 0.).astype(dtype)) + attention_bias = lax.select( + mask > 0, + jnp.full(mask.shape, -1e10).astype(dtype), + jnp.full(mask.shape, 0.0).astype(dtype), + ) if bias is not None: attention_bias = _combine_biases(attention_bias, bias) @@ -186,8 +190,9 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods # For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED # and kernel is unavailable, then try on pure scaled softmax custom calls. 
- if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, - dtype): + if is_softmax_kernel_available( + SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype + ): outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED) else: outputs = jax_nn.softmax(logits * self.scale_factor) @@ -195,7 +200,7 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods return outputs -class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods +class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods r""" Applies layer normalization over a mini-batch of inputs. There are two types of normalization supported by this module, @@ -262,19 +267,21 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods and sequence length dimension. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). """ + epsilon: float = 1e-6 - layernorm_type: str = 'layernorm' + layernorm_type: str = "layernorm" zero_centered_gamma: bool = False scale_init: Initializer = None - scale_axes: Tuple[str, ...] = ('embed',) + scale_axes: Tuple[str, ...] = ("embed",) bias_init: Initializer = nn.initializers.zeros - bias_axes: Tuple[str, ...] = ('embed',) + bias_axes: Tuple[str, ...] = ("embed",) dtype: DType = jnp.float32 transpose_batch_sequence: bool = False def __post_init__(self): self.scale_init = _obtain_default_layernorm_scale_init_if_need( - self.scale_init, self.zero_centered_gamma) + self.scale_init, self.zero_centered_gamma + ) super().__post_init__() @nn.compact @@ -294,18 +301,26 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods """ features = x.shape[-1] - scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,), - self.scale_init, self.scale_axes, - self.bias_init, self.bias_axes, self.dtype) - return layernorm(x, - scale, - ln_bias, - layernorm_type=self.layernorm_type, - zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon) - - -class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-methods + scale, ln_bias = _create_layernorm_parameters( + self.layernorm_type, + (features,), + self.scale_init, + self.scale_axes, + self.bias_init, + self.bias_axes, + self.dtype, + ) + return layernorm( + x, + scale, + ln_bias, + layernorm_type=self.layernorm_type, + zero_centered_gamma=self.zero_centered_gamma, + epsilon=self.epsilon, + ) + + +class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-methods """ Base class of transformer engine """ @@ -321,18 +336,23 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-meth grad_name_post_fix = f"_g_{postfix}" def generate_a_set(target_postfix): - amax = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME, - f"{FP8Helper.FP8_AMAX_NAME}{target_postfix}", - jnp.zeros, (FP8Helper.AMAX_HISTORY_LEN,), - jnp.float32, - axes=(None,)) + amax = nn_partitioning.variable_with_axes( + FP8Helper.FP8_COLLECTION_NAME, + f"{FP8Helper.FP8_AMAX_NAME}{target_postfix}", + jnp.zeros, + (FP8Helper.AMAX_HISTORY_LEN,), + jnp.float32, + axes=(None,), + ) scale = nn_partitioning.variable_with_axes( FP8Helper.FP8_COLLECTION_NAME, f"{FP8Helper.FP8_SCALE_NAME}{target_postfix}", - jnp.ones, (1,), + jnp.ones, + (1,), jnp.float32, - axes=(None,)) + axes=(None,), + ) return amax.value, scale.value @@ -340,8 +360,9 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-meth weight_amax, weight_scale = 
generate_a_set(weight_name_post_fix) grad_amax, grad_scale = generate_a_set(grad_name_post_fix) - return FP8MetaPackage(input_amax, input_scale, weight_amax, weight_scale, grad_amax, - grad_scale) + return FP8MetaPackage( + input_amax, input_scale, weight_amax, weight_scale, grad_amax, grad_scale + ) class DenseGeneral(TransformerEngineBase): @@ -403,7 +424,7 @@ class DenseGeneral(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') + self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") super().__post_init__() @nn.compact @@ -430,20 +451,16 @@ class DenseGeneral(TransformerEngineBase): kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features - kernel = nn_partitioning.param_with_axes('kernel', - self.kernel_init, - kernel_param_shape, - jnp.float32, - axes=self.kernel_axes) + kernel = nn_partitioning.param_with_axes( + "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes + ) kernel = jnp.reshape(kernel, kernel_shape) if self.use_bias: - bias = nn_partitioning.param_with_axes('bias', - self.bias_init, - features, - jnp.float32, - axes=self.bias_axes) + bias = nn_partitioning.param_with_axes( + "bias", self.bias_init, features, jnp.float32, axes=self.bias_axes + ) bias = bias.astype(self.dtype) else: bias = None @@ -453,36 +470,46 @@ class DenseGeneral(TransformerEngineBase): if FP8Helper.is_fp8_enabled(): fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0") - y = type_safe_dot_general(inputs, - kernel, - fp8_meta_pkg=fp8_meta_pkg, - contracting_dims=(axis, contract_ind)) + y = type_safe_dot_general( + inputs, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=(axis, contract_ind) + ) if self.enable_low_rank_adaptation: - lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1], - self.low_rank_adaptation_dim) - lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1], - self.low_rank_adaptation_dim) + lora_a_kernel_shape = ( + *kernel_shape[: len(axis)], + *features[:-1], + self.low_rank_adaptation_dim, + ) + lora_a_kernel_init_shape = ( + kernel_param_shape[0], + *features[:-1], + self.low_rank_adaptation_dim, + ) lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape) - lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel', - self.kernel_init, - lora_a_kernel_init_shape, - jnp.float32, - axes=lora_a_kernel_axes) + lora_a_kernel = nn_partitioning.param_with_axes( + "lora_a_kernel", + self.kernel_init, + lora_a_kernel_init_shape, + jnp.float32, + axes=lora_a_kernel_axes, + ) lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) lora_a_kernel = lora_a_kernel.astype(self.dtype) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape) - lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel', - nn.initializers.zeros, - lora_b_kernel_shape, - jnp.float32, - axes=lora_b_kernel_axes) + lora_b_kernel = nn_partitioning.param_with_axes( + "lora_b_kernel", + nn.initializers.zeros, + lora_b_kernel_shape, + jnp.float32, + axes=lora_b_kernel_axes, + ) lora_b_kernel = lora_b_kernel.astype(self.dtype) - y += _apply_low_rank_adaptation(inputs, axis, features, lora_a_kernel, lora_b_kernel, - self.low_rank_adaptation_alpha) + y += _apply_low_rank_adaptation( + inputs, axis, features, lora_a_kernel, 
lora_b_kernel, self.low_rank_adaptation_alpha + ) if bias is not None: bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape @@ -581,13 +608,13 @@ class LayerNormDenseGeneral(TransformerEngineBase): features: Union[Iterable[int], int] enable_layernorm: bool = True - layernorm_type: str = 'layernorm' + layernorm_type: str = "layernorm" epsilon: float = 1e-6 zero_centered_gamma: bool = False scale_init: Initializer = None - scale_axes: Tuple[str, ...] = ('embed',) + scale_axes: Tuple[str, ...] = ("embed",) ln_bias_init: Initializer = nn.initializers.zeros - ln_bias_axes: Tuple[str, ...] = ('embed',) + ln_bias_axes: Tuple[str, ...] = ("embed",) kernel_init: Initializer = None kernel_axes: Tuple[str, ...] = () use_bias: bool = False @@ -606,9 +633,10 @@ class LayerNormDenseGeneral(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') + self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") self.scale_init = _obtain_default_layernorm_scale_init_if_need( - self.scale_init, self.zero_centered_gamma) + self.scale_init, self.zero_centered_gamma + ) super().__post_init__() @nn.compact @@ -632,27 +660,37 @@ class LayerNormDenseGeneral(TransformerEngineBase): ln_output = None - fuse_layernorm = FP8Helper.is_fp8_enabled( - ) and not self.return_layernorm_output and self.enable_layernorm + fuse_layernorm = ( + FP8Helper.is_fp8_enabled() + and not self.return_layernorm_output + and self.enable_layernorm + ) if self.enable_layernorm: inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) - assert self.axis == -1 # Only support axis = =-1 at this moment + assert self.axis == -1 # Only support axis = =-1 at this moment features = inputs.shape[-1] - scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,), - self.scale_init, self.scale_axes, - self.ln_bias_init, self.ln_bias_axes, - self.dtype) + scale, ln_bias = _create_layernorm_parameters( + self.layernorm_type, + (features,), + self.scale_init, + self.scale_axes, + self.ln_bias_init, + self.ln_bias_axes, + self.dtype, + ) if not fuse_layernorm: - y = layernorm(inputs, - scale, - ln_bias, - layernorm_type=self.layernorm_type, - zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon) + y = layernorm( + inputs, + scale, + ln_bias, + layernorm_type=self.layernorm_type, + zero_centered_gamma=self.zero_centered_gamma, + epsilon=self.epsilon, + ) else: assert not self.return_layernorm_output y = inputs @@ -670,11 +708,9 @@ class LayerNormDenseGeneral(TransformerEngineBase): kernel_shape = tuple(y.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features - kernel = nn_partitioning.param_with_axes('kernel', - self.kernel_init, - kernel_param_shape, - jnp.float32, - axes=self.kernel_axes) + kernel = nn_partitioning.param_with_axes( + "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes + ) kernel = jnp.reshape(kernel, kernel_shape) @@ -685,56 +721,66 @@ class LayerNormDenseGeneral(TransformerEngineBase): fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0") if fuse_layernorm: - z = layernorm_fp8_dot(y, - kernel, - scale, - ln_bias, - fp8_meta_pkg, - self.layernorm_type, - zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon, - layernorm_input_axes=self.layernorm_input_axes, - dot_input_axes=self.dot_input_axes) + z = layernorm_fp8_dot( + y, + 
kernel, + scale, + ln_bias, + fp8_meta_pkg, + self.layernorm_type, + zero_centered_gamma=self.zero_centered_gamma, + epsilon=self.epsilon, + layernorm_input_axes=self.layernorm_input_axes, + dot_input_axes=self.dot_input_axes, + ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) - z = type_safe_dot_general(y, - kernel, - fp8_meta_pkg=fp8_meta_pkg, - contracting_dims=(axis, contract_ind)) + z = type_safe_dot_general( + y, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=(axis, contract_ind) + ) if self.enable_low_rank_adaptation: - lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1], - self.low_rank_adaptation_dim) - lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1], - self.low_rank_adaptation_dim) + lora_a_kernel_shape = ( + *kernel_shape[: len(axis)], + *features[:-1], + self.low_rank_adaptation_dim, + ) + lora_a_kernel_init_shape = ( + kernel_param_shape[0], + *features[:-1], + self.low_rank_adaptation_dim, + ) lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape) - lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel', - self.kernel_init, - lora_a_kernel_init_shape, - jnp.float32, - axes=lora_a_kernel_axes) + lora_a_kernel = nn_partitioning.param_with_axes( + "lora_a_kernel", + self.kernel_init, + lora_a_kernel_init_shape, + jnp.float32, + axes=lora_a_kernel_axes, + ) lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) lora_a_kernel = lora_a_kernel.astype(self.dtype) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape) - lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel', - nn.initializers.zeros, - lora_b_kernel_shape, - jnp.float32, - axes=lora_b_kernel_axes) + lora_b_kernel = nn_partitioning.param_with_axes( + "lora_b_kernel", + nn.initializers.zeros, + lora_b_kernel_shape, + jnp.float32, + axes=lora_b_kernel_axes, + ) lora_b_kernel = lora_b_kernel.astype(self.dtype) - z += _apply_low_rank_adaptation(y, axis, features, lora_a_kernel, lora_b_kernel, - self.low_rank_adaptation_alpha) + z += _apply_low_rank_adaptation( + y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha + ) bias = None if self.use_bias: - bias = nn_partitioning.param_with_axes('bias', - self.bias_init, - features, - jnp.float32, - axes=self.bias_axes) + bias = nn_partitioning.param_with_axes( + "bias", self.bias_init, features, jnp.float32, axes=self.bias_axes + ) bias = bias.astype(self.dtype) if bias is not None: @@ -744,7 +790,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): if self.depth_scaling is not None: z = z / self.depth_scaling - return z, ln_output # dense_output, layer_norm_output + return z, ln_output # dense_output, layer_norm_output class LayerNormMLP(TransformerEngineBase): @@ -858,23 +904,23 @@ class LayerNormMLP(TransformerEngineBase): intermediate_dim: int = 2048 enable_layernorm: bool = True - layernorm_type: str = 'layernorm' + layernorm_type: str = "layernorm" epsilon: float = 1e-6 zero_centered_gamma: bool = False scale_init: Initializer = None - scale_axes: Tuple[str, ...] = ('embed',) + scale_axes: Tuple[str, ...] = ("embed",) ln_bias_init: Initializer = nn.initializers.zeros - ln_bias_axes: Tuple[str, ...] = ('embed',) + ln_bias_axes: Tuple[str, ...] = ("embed",) kernel_init: Initializer = None - kernel_axes_1: Tuple[str, ...] = ('embed', 'act', 'mlp') - kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed') + kernel_axes_1: Tuple[str, ...] 
= ("embed", "act", "mlp") + kernel_axes_2: Tuple[str, ...] = ("mlp", "embed") use_bias: bool = False bias_init: Initializer = nn.initializers.zeros - bias_axes_1: Tuple[str, ...] = ('act', 'mlp') - bias_axes_2: Tuple[str, ...] = ('embed',) + bias_axes_1: Tuple[str, ...] = ("act", "mlp") + bias_axes_2: Tuple[str, ...] = ("embed",) return_layernorm_output: bool = True - activations: Sequence[Union[str, Callable]] = ('relu',) - intermediate_dropout_rng_name: str = 'dropout' + activations: Sequence[Union[str, Callable]] = ("relu",) + intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () enable_low_rank_adaptation: bool = False @@ -889,9 +935,10 @@ class LayerNormMLP(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') + self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") self.scale_init = _obtain_default_layernorm_scale_init_if_need( - self.scale_init, self.zero_centered_gamma) + self.scale_init, self.zero_centered_gamma + ) super().__post_init__() @nn.compact @@ -917,44 +964,61 @@ class LayerNormMLP(TransformerEngineBase): ln_output = None - fuse_layernorm = FP8Helper.is_fp8_enabled( - ) and not self.return_layernorm_output and self.enable_layernorm - - gated_act_pool = [('gelu', 'linear'), ('silu', 'linear'), ('relu', 'linear'), - ('quick_gelu', 'linear'), ('squared_relu', 'linear')] - act_pool = [('gelu',), ('silu',), ('relu',), ('quick_gelu',), ('squared_relu',)] + fuse_layernorm = ( + FP8Helper.is_fp8_enabled() + and not self.return_layernorm_output + and self.enable_layernorm + ) + + gated_act_pool = [ + ("gelu", "linear"), + ("silu", "linear"), + ("relu", "linear"), + ("quick_gelu", "linear"), + ("squared_relu", "linear"), + ] + act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)] normalized_acts = [] for act in self.activations: if not isinstance(act, str): return False normalized_acts.append(act.lower()) normalized_acts = tuple( - reversed(normalized_acts) if normalized_acts[0] == 'linear' else normalized_acts) + reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts + ) is_act_implemented = normalized_acts in (gated_act_pool + act_pool) - use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\ - self.intermediate_dropout_rate < 1e-3 + use_fused_layernorm_mlp = ( + fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3 + ) # LayerNorm if self.enable_layernorm: - assert self.axis == -1 # Only support axis == -1 at this moment + assert self.axis == -1 # Only support axis == -1 at this moment inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) features = inputs.shape[-1] - scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,), - self.scale_init, self.scale_axes, - self.ln_bias_init, self.ln_bias_axes, - self.dtype) + scale, ln_bias = _create_layernorm_parameters( + self.layernorm_type, + (features,), + self.scale_init, + self.scale_axes, + self.ln_bias_init, + self.ln_bias_axes, + self.dtype, + ) if not fuse_layernorm: - y = layernorm(inputs, - scale, - ln_bias, - layernorm_type=self.layernorm_type, - zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon) + y = layernorm( + inputs, + scale, + ln_bias, + layernorm_type=self.layernorm_type, + zero_centered_gamma=self.zero_centered_gamma, + 
epsilon=self.epsilon, + ) else: assert not self.return_layernorm_output y = inputs @@ -984,125 +1048,149 @@ class LayerNormMLP(TransformerEngineBase): intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim)) kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim) - kernel_1 = nn_partitioning.param_with_axes('wi_kernel', - kernel_1_init, - num_activations, - -2, - kernel_1_each_shape, - jnp.float32, - axes=self.kernel_axes_1) + kernel_1 = nn_partitioning.param_with_axes( + "wi_kernel", + kernel_1_init, + num_activations, + -2, + kernel_1_each_shape, + jnp.float32, + axes=self.kernel_axes_1, + ) kernel_1 = jnp.reshape(kernel_1, kernel_1_shape) hidden_size = inputs.shape[-1] hidden_size_tuple = _canonicalize_tuple(hidden_size) kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple)) - kernel_2 = nn_partitioning.param_with_axes('wo_kernel', - self.kernel_init, - kernel_2_param_shape, - jnp.float32, - axes=self.kernel_axes_2) + kernel_2 = nn_partitioning.param_with_axes( + "wo_kernel", + self.kernel_init, + kernel_2_param_shape, + jnp.float32, + axes=self.kernel_axes_2, + ) kernel_2 = jnp.reshape(kernel_2, kernel_2_shape) contract_ind = tuple(range(0, len(axis))) - ffn1_ckpt_name = 'ffn1' - ffn2_ckpt_name = 'ffn2' + ffn1_ckpt_name = "ffn1" + ffn2_ckpt_name = "ffn2" if use_fused_layernorm_mlp: - assert self.axis == -1 # Only support axis = =-1 at this moment + assert self.axis == -1 # Only support axis = =-1 at this moment if self.use_bias: bias_1_shape = intermediate_dim - bias_1 = nn_partitioning.param_with_axes('wi_bias', - self.bias_init, - bias_1_shape, - jnp.float32, - axes=self.bias_axes_1) + bias_1 = nn_partitioning.param_with_axes( + "wi_bias", self.bias_init, bias_1_shape, jnp.float32, axes=self.bias_axes_1 + ) bias_1 = bias_1.astype(self.dtype) bias_2_shape = (hidden_size,) - bias_2 = nn_partitioning.param_with_axes('wo_bias', - self.bias_init, - bias_2_shape, - jnp.float32, - axes=self.bias_axes_2) + bias_2 = nn_partitioning.param_with_axes( + "wo_bias", self.bias_init, bias_2_shape, jnp.float32, axes=self.bias_axes_2 + ) bias_2 = bias_2.astype(self.dtype) else: bias_1 = None bias_2 = None - out = fused_layernorm_fp8_mlp(y, - scale, - ln_bias, [kernel_1, kernel_2], [bias_1, bias_2], - [wi_fp8_meta_pkg, wo_fp8_meta_pkg], - self.layernorm_type, - zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon, - layernorm_input_axes=self.layernorm_input_axes, - dot_1_input_axes=self.dot_1_input_axes, - dot_2_input_axes=self.dot_2_input_axes, - ffn1_ckpt_name=ffn1_ckpt_name, - ffn2_ckpt_name=ffn2_ckpt_name, - activation_type=normalized_acts, - use_bias=self.use_bias) - - else: # not use_fused_ln_geglu_mlp + out = fused_layernorm_fp8_mlp( + y, + scale, + ln_bias, + [kernel_1, kernel_2], + [bias_1, bias_2], + [wi_fp8_meta_pkg, wo_fp8_meta_pkg], + self.layernorm_type, + zero_centered_gamma=self.zero_centered_gamma, + epsilon=self.epsilon, + layernorm_input_axes=self.layernorm_input_axes, + dot_1_input_axes=self.dot_1_input_axes, + dot_2_input_axes=self.dot_2_input_axes, + ffn1_ckpt_name=ffn1_ckpt_name, + ffn2_ckpt_name=ffn2_ckpt_name, + activation_type=normalized_acts, + use_bias=self.use_bias, + ) + + else: # not use_fused_ln_geglu_mlp # DenseGeneral 1 if fuse_layernorm: - x = layernorm_fp8_dot(y, - kernel_1, - scale, - ln_bias, - wi_fp8_meta_pkg, - self.layernorm_type, - 
zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon, - layernorm_input_axes=self.layernorm_input_axes, - dot_input_axes=self.dot_1_input_axes) + x = layernorm_fp8_dot( + y, + kernel_1, + scale, + ln_bias, + wi_fp8_meta_pkg, + self.layernorm_type, + zero_centered_gamma=self.zero_centered_gamma, + epsilon=self.epsilon, + layernorm_input_axes=self.layernorm_input_axes, + dot_input_axes=self.dot_1_input_axes, + ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes) - x = type_safe_dot_general(y, - kernel_1, - fp8_meta_pkg=wi_fp8_meta_pkg, - contracting_dims=(axis, contract_ind)) + x = type_safe_dot_general( + y, kernel_1, fp8_meta_pkg=wi_fp8_meta_pkg, contracting_dims=(axis, contract_ind) + ) if self.enable_low_rank_adaptation: - wi_lora_a_kernel_shape = (*kernel_1_shape[:len(axis)], num_activations, - self.low_rank_adaptation_dim) - wi_lora_a_kernel_init_shape = (kernel_1_each_shape[0], num_activations, - self.low_rank_adaptation_dim) - wi_lora_a_kernel_init_each_shape = (kernel_1_each_shape[0], - self.low_rank_adaptation_dim) + wi_lora_a_kernel_shape = ( + *kernel_1_shape[: len(axis)], + num_activations, + self.low_rank_adaptation_dim, + ) + wi_lora_a_kernel_init_shape = ( + kernel_1_each_shape[0], + num_activations, + self.low_rank_adaptation_dim, + ) + wi_lora_a_kernel_init_each_shape = ( + kernel_1_each_shape[0], + self.low_rank_adaptation_dim, + ) wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape) - wi_lora_a_kernel = nn_partitioning.param_with_axes('wi_lora_a_kernel', - kernel_1_init, - num_activations, - -2, - wi_lora_a_kernel_init_each_shape, - jnp.float32, - axes=wi_lora_a_kernel_axes) + wi_lora_a_kernel = nn_partitioning.param_with_axes( + "wi_lora_a_kernel", + kernel_1_init, + num_activations, + -2, + wi_lora_a_kernel_init_each_shape, + jnp.float32, + axes=wi_lora_a_kernel_axes, + ) wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape) wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype) - wi_lora_b_kernel_shape = (num_activations, self.low_rank_adaptation_dim, - self.intermediate_dim) + wi_lora_b_kernel_shape = ( + num_activations, + self.low_rank_adaptation_dim, + self.intermediate_dim, + ) wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape) - wi_lora_b_kernel = nn_partitioning.param_with_axes('wi_lora_b_kernel', - nn.initializers.zeros, - wi_lora_b_kernel_shape, - jnp.float32, - axes=wi_lora_b_kernel_axes) + wi_lora_b_kernel = nn_partitioning.param_with_axes( + "wi_lora_b_kernel", + nn.initializers.zeros, + wi_lora_b_kernel_shape, + jnp.float32, + axes=wi_lora_b_kernel_axes, + ) wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype) - x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel, - wi_lora_b_kernel, self.low_rank_adaptation_alpha) + x += _apply_low_rank_adaptation( + y, + axis, + intermediate_dim, + wi_lora_a_kernel, + wi_lora_b_kernel, + self.low_rank_adaptation_alpha, + ) bias_1 = None if self.use_bias: - bias_1 = nn_partitioning.param_with_axes('wi_bias', - self.bias_init, - intermediate_dim, - jnp.float32, - axes=self.bias_axes_1) + bias_1 = nn_partitioning.param_with_axes( + "wi_bias", self.bias_init, intermediate_dim, jnp.float32, axes=self.bias_axes_1 + ) bias_1 = bias_1.astype(self.dtype) bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape x += jnp.reshape(bias_1, bias_1_shape) @@ -1120,50 +1208,59 @@ class LayerNormMLP(TransformerEngineBase): # Remove act axis z = jnp.reshape(z, (*z.shape[:-2], -1)) - z = 
nn.Dropout(rate=self.intermediate_dropout_rate, - broadcast_dims=self.intermediate_hidden_dropout_dims, - rng_collection=self.intermediate_dropout_rng_name)( - z, deterministic=deterministic) + z = nn.Dropout( + rate=self.intermediate_dropout_rate, + broadcast_dims=self.intermediate_hidden_dropout_dims, + rng_collection=self.intermediate_dropout_rng_name, + )(z, deterministic=deterministic) z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes) # DenseGeneral 2 - out = type_safe_dot_general(z, - kernel_2, - fp8_meta_pkg=wo_fp8_meta_pkg, - contracting_dims=(axis, contract_ind)) + out = type_safe_dot_general( + z, kernel_2, fp8_meta_pkg=wo_fp8_meta_pkg, contracting_dims=(axis, contract_ind) + ) if self.enable_low_rank_adaptation: wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim) wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape) - wo_lora_a_kernel = nn_partitioning.param_with_axes('wo_lora_a_kernel', - self.kernel_init, - wo_lora_a_kernel_shape, - jnp.float32, - axes=wo_lora_a_kernel_axes) + wo_lora_a_kernel = nn_partitioning.param_with_axes( + "wo_lora_a_kernel", + self.kernel_init, + wo_lora_a_kernel_shape, + jnp.float32, + axes=wo_lora_a_kernel_axes, + ) wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype) wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size) wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape) - wo_lora_b_kernel = nn_partitioning.param_with_axes('wo_lora_b_kernel', - nn.initializers.zeros, - wo_lora_b_kernel_shape, - jnp.float32, - axes=wo_lora_b_kernel_axes) + wo_lora_b_kernel = nn_partitioning.param_with_axes( + "wo_lora_b_kernel", + nn.initializers.zeros, + wo_lora_b_kernel_shape, + jnp.float32, + axes=wo_lora_b_kernel_axes, + ) wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype) - out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel, - wo_lora_b_kernel, self.low_rank_adaptation_alpha) + out += _apply_low_rank_adaptation( + z, + axis, + hidden_size_tuple, + wo_lora_a_kernel, + wo_lora_b_kernel, + self.low_rank_adaptation_alpha, + ) bias_2 = None if self.use_bias: - bias_2 = nn_partitioning.param_with_axes('wo_bias', - self.bias_init, (hidden_size,), - jnp.float32, - axes=self.bias_axes_2) + bias_2 = nn_partitioning.param_with_axes( + "wo_bias", self.bias_init, (hidden_size,), jnp.float32, axes=self.bias_axes_2 + ) bias_2 = bias_2.astype(self.dtype) out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,)) out = checkpoint_name(out, ffn2_ckpt_name) - return out, ln_output # Output, layner_norm_output + return out, ln_output # Output, layner_norm_output diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 224f61f0f4a466f03537a5d5a7ca5506b5592bf3..b98f2ed7b4ae871ca67d7c4eb98c22f6e48c266f 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -39,8 +39,9 @@ PRNGKey = Any Shape = Tuple[int, ...] 
DType = jnp.dtype Array = jnp.ndarray -PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, - lax.Precision]] +PrecisionLike = Union[ + None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] +] Initializer = Callable[[PRNGKey, Shape, DType], Array] LogicalRules = Sequence[Tuple[str, Union[str, None]]] @@ -82,14 +83,13 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: """ rules_map = {} for item in rules: - assert len(item) == 2, \ - "The logical axis rule should be like (axis_name, mesh_axis_name)." + assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)." key = item[0] val = item[1] - assert isinstance(key, str), \ - f"Thie axis_name should be str, but got {type(key)}." - assert isinstance(val, str) or (val is None), \ - f"Thie mesh_axis_name should be str or None, but got {type(val)}." + assert isinstance(key, str), f"Thie axis_name should be str, but got {type(key)}." + assert isinstance(val, str) or ( + val is None + ), f"Thie mesh_axis_name should be str or None, but got {type(val)}." if key in rules_map: rules_map[key].append(val) else: @@ -100,17 +100,18 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: key = item[0] val = item[1] if key in rules_map: - assert len(rules_map[key]) == 1 and rules_map[key][0] == val, \ - f"The rule diverged between TE and given rule." \ - f"Axis:{key} map to {rules_map[key]} in the given" \ + assert len(rules_map[key]) == 1 and rules_map[key][0] == val, ( + "The rule diverged between TE and given rule." + f"Axis:{key} map to {rules_map[key]} in the given" f" rules, but {val} in TE's rules." + ) else: extended_rules.append(item) return tuple(extended_rules) -class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods - attention_dropout: float = 0. +class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods + attention_dropout: float = 0.0 attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK attn_bias_type: Optional[AttnBiasType] = None dtype: DType = jnp.float32 @@ -119,23 +120,26 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi transpose_batch_sequence: bool = True @nn.compact - def __call__(self, - query: Array, - key: Array, - value: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - *, - dropout_rng: Optional[PRNGKey] = None, - deterministic: bool = False) -> Array: - assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' + def __call__( + self, + query: Array, + key: Array, + value: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + *, + dropout_rng: Optional[PRNGKey] = None, + deterministic: bool = False, + ) -> Array: + assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank." batch_dim = 1 if self.transpose_batch_sequence else 0 - assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], ( - 'q, k, v batch dims must match.') + assert ( + query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim] + ), "q, k, v batch dims must match." sequence_dim = 0 if self.transpose_batch_sequence else 1 - assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.' - assert key.shape[-2] == value.shape[-2], 'k, v num_attention_heads must match.' - assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.' 
+ assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match." + assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match." + assert query.shape[-1] == key.shape[-1], "q, k head_dim must match." if self.scale_factor is None: scale_factor = 1.0 / sqrt(query.shape[-1]) @@ -149,7 +153,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi h_q, h_kv = query.shape[-2], key.shape[-2] # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv. # Therefore, we have to maintain two code paths. - is_gqa = (h_q != h_kv) + is_gqa = h_q != h_kv if is_gqa: assert (h_q % h_kv == 0) and (h_q >= h_kv) @@ -158,16 +162,16 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi if self.transpose_batch_sequence: if is_gqa: - attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key) + attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key) else: - attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key) + attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key) else: if is_gqa: - attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key) + attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key) else: - attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) + attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key) - attn_weights = checkpoint_name(attn_weights, 'logits') + attn_weights = checkpoint_name(attn_weights, "logits") if is_gqa: b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape @@ -175,13 +179,14 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) attn_weights = with_sharding_constraint_by_logical_axes( - attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)) + attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES) + ) # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias) # In this case, the scale can not fused into the Softmax module. if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS: attn_weights = attn_weights * scale_factor - fused_scale_factor = 1. 
+ fused_scale_factor = 1.0 else: # If not post_scale_bias, the scale can be fused into Softmax module fused_scale_factor = scale_factor @@ -199,39 +204,40 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi if mask is not None: return SoftmaxType.SCALED_MASKED, mask return SoftmaxType.SCALED, mask - raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type=" - "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}") + raise ValueError( + f"Unsupported {attn_mask_type=}, supported attn_mask_type=" + "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}" + ) softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask) - attn_weights = Softmax(softmax_type=softmax_type, - scale_factor=fused_scale_factor)(attn_weights, mask, - bias).astype(self.dtype) + attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)( + attn_weights, mask, bias + ).astype(self.dtype) if is_gqa: attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) - if not deterministic and self.attention_dropout > 0.: + if not deterministic and self.attention_dropout > 0.0: keep_prob = 1.0 - self.attention_dropout dropout_shape = list(attn_weights.shape) # TODO(rewang): add attention dropout broadcast dimension arguments for users keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) - multiplier = (keep.astype(attn_weights.dtype) / - jnp.asarray(keep_prob, dtype=self.dtype)) + multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype) attn_weights = attn_weights * multiplier if self.transpose_batch_sequence: if is_gqa: - return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape) - return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value) + return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape) + return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value) if is_gqa: - return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape) - return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) + return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape) + return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value) -class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods - attention_dropout: float = 0. 
+class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods + attention_dropout: float = 0.0 attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK attn_bias_type: Optional[AttnBiasType] = None dtype: DType = jnp.float32 @@ -240,15 +246,17 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- transpose_batch_sequence: bool = False @nn.compact - def __call__(self, - query: Array, - key: Array, - value: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - *, - dropout_rng: Optional[PRNGKey] = None, - deterministic: bool = False) -> Array: + def __call__( + self, + query: Array, + key: Array, + value: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + *, + dropout_rng: Optional[PRNGKey] = None, + deterministic: bool = False, + ) -> Array: seed = None if dropout_rng is not None: @@ -269,15 +277,17 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- qkv_packed = query if self.transpose_batch_sequence: qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4]) - x = fused_attn_qkvpacked(qkv_packed, - bias, - mask, - seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic) + x = fused_attn_qkvpacked( + qkv_packed, + bias, + mask, + seed, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + scaling_factor=scale_factor, + dropout_probability=self.attention_dropout, + is_training=not deterministic, + ) elif self.qkv_layout == QKVLayout.BSHD_BS2HD: """kvpacked format, treat query: query tensor, shape = [..., h, d] @@ -288,32 +298,36 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- if self.transpose_batch_sequence: query = query.transpose([1, 0, 2, 3]) kv_packed = kv_packed.transpose([1, 0, 2, 3, 4]) - x = fused_attn_kvpacked(query, - kv_packed, - bias, - mask, - seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic) + x = fused_attn_kvpacked( + query, + kv_packed, + bias, + mask, + seed, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + scaling_factor=scale_factor, + dropout_probability=self.attention_dropout, + is_training=not deterministic, + ) elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD: if self.transpose_batch_sequence: query = query.transpose([1, 0, 2, 3]) key = key.transpose([1, 0, 2, 3]) value = value.transpose([1, 0, 2, 3]) - x = fused_attn(query, - key, - value, - bias, - mask, - seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic) + x = fused_attn( + query, + key, + value, + bias, + mask, + seed, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + scaling_factor=scale_factor, + dropout_probability=self.attention_dropout, + is_training=not deterministic, + ) else: raise ValueError(f"Unsupported {self.qkv_layout=}.") @@ -323,7 +337,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- return x -class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods +class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods r""" Dot Product Attention (DPA). 
Allows the model to jointly attend to information from different representation subspaces as described in the paper: @@ -423,28 +437,31 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. """ + head_dim: int num_attention_heads: int num_gqa_groups: Optional[int] = None - attention_dropout: float = 0. - attn_mask_type: AttnMaskType = 'causal' + attention_dropout: float = 0.0 + attn_mask_type: AttnMaskType = "causal" attn_bias_type: AttnBiasType = None dtype: DType = jnp.float32 - dropout_rng_name: str = 'dropout' + dropout_rng_name: str = "dropout" float32_logits: bool = False - qkv_layout: str = 'bshd_bshd_bshd' + qkv_layout: str = "bshd_bshd_bshd" scale_factor: Optional[float] = None transpose_batch_sequence: bool = True @nn.compact - def __call__(self, - query: Array, - key: Array, - value: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - *, - deterministic: bool = False) -> Array: + def __call__( + self, + query: Array, + key: Array, + value: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + *, + deterministic: bool = False, + ) -> Array: """ Parameters ---------- @@ -494,25 +511,34 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method else: seqlen_kv = key.shape[sequence_dim] - has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout, - attn_bias_type, attn_mask_type, - self.attention_dropout, - self.num_attention_heads, - self.num_gqa_groups, seqlen_q, - seqlen_kv, self.head_dim) + has_fused_attn_kernel = is_fused_attn_kernel_available( + self.dtype, + self.dtype, + qkv_layout, + attn_bias_type, + attn_mask_type, + self.attention_dropout, + self.num_attention_heads, + self.num_gqa_groups, + seqlen_q, + seqlen_kv, + self.head_dim, + ) - use_fused_attn = (enable_fused_attn and has_fused_attn_kernel) + use_fused_attn = enable_fused_attn and has_fused_attn_kernel if enable_fused_attn and not has_fused_attn_kernel: - warnings.warn("Fused attention is not enabled because there is no available kernel.\n" - "Fall back to the unfused attention.\n" - "Please try to update the cuDNN and TE to the latest version.\n" - f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n" - f"{self.attention_dropout=}\n{self.num_attention_heads=}\n" - f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n") + warnings.warn( + "Fused attention is not enabled because there is no available kernel.\n" + "Fall back to the unfused attention.\n" + "Please try to update the cuDNN and TE to the latest version.\n" + f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n" + f"{self.attention_dropout=}\n{self.num_attention_heads=}\n" + f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n" + ) dropout_rng = None - if not deterministic and self.attention_dropout > 0.: + if not deterministic and self.attention_dropout > 0.0: dropout_rng = self.make_rng(self.dropout_rng_name) if self.scale_factor is None: @@ -525,28 +551,24 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method # unfused attention only supports splitted query, key, value if qkv_layout == QKVLayout.BS3HD: query, key, value = jnp.split(query, [1, 2], axis=-3) - query, key, value = map(functools.partial(jnp.squeeze, axis=-3), - [query, key, value]) + query, key, value = map( + functools.partial(jnp.squeeze, axis=-3), [query, key, 
value] + ) elif qkv_layout == QKVLayout.BSHD_BS2HD: key, value = jnp.split(key, [1], axis=-3) key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value]) else: assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD - x = _UnfusedDotProductAttention(attention_dropout=self.attention_dropout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - dtype=self.dtype, - float32_logits=self.float32_logits, - scale_factor=scale_factor, - transpose_batch_sequence=self.transpose_batch_sequence)( - query, - key, - value, - mask, - bias, - dropout_rng=dropout_rng, - deterministic=deterministic) + x = _UnfusedDotProductAttention( + attention_dropout=self.attention_dropout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + dtype=self.dtype, + float32_logits=self.float32_logits, + scale_factor=scale_factor, + transpose_batch_sequence=self.transpose_batch_sequence, + )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic) else: x = _FusedDotProductAttention( attention_dropout=self.attention_dropout, @@ -561,10 +583,12 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method return x -def rotary_pos_emb(x: Array, - windows: Tuple[int, int], - transpose_batch_sequence: bool, - group_method: str = 'consecutive'): +def rotary_pos_emb( + x: Array, + windows: Tuple[int, int], + transpose_batch_sequence: bool, + group_method: str = "consecutive", +): """ Rotary Positional Embedding x should be in shape of @@ -577,7 +601,7 @@ def rotary_pos_emb(x: Array, max_window = windows[1] fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim - time_scales = min_window * (max_window / min_window)**fraction + time_scales = min_window * (max_window / min_window) ** fraction time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1))) batch_dim = 1 if transpose_batch_sequence else 0 @@ -623,21 +647,22 @@ def rotary_pos_emb(x: Array, return output def canonicalize_group_method(gm): - canonicalized_gm = gm.lower().strip().replace('-', '').replace('_', '') - assert canonicalized_gm in ['consecutive', 'alternate'], \ - f"Invalid relative positional embedding group method. " \ + canonicalized_gm = gm.lower().strip().replace("-", "").replace("_", "") + assert canonicalized_gm in ["consecutive", "alternate"], ( + "Invalid relative positional embedding group method. " f"Expect to be in []'alternate' or 'consecutive'], but got {gm}." 
+ ) return canonicalized_gm group_method = canonicalize_group_method(group_method) - if group_method == 'alternate': + if group_method == "alternate": return alternate_impl() return consecutive_impl() -class LoRAScope: # pylint: disable=too-few-public-methods +class LoRAScope: # pylint: disable=too-few-public-methods """LoRA Scope""" def __init__(self, qkv_proj=False, output_proj=False, mlp=False): @@ -646,28 +671,37 @@ class LoRAScope: # pylint: disable=too-few-public-methods self.mlp = mlp def __eq__(self, other): - return (self.qkv_proj, self.output_proj, self.mlp) == \ - (other.qkv_proj, other.output_proj, other.mlp) + return (self.qkv_proj, self.output_proj, self.mlp) == ( + other.qkv_proj, + other.output_proj, + other.mlp, + ) def _canonicalize_lora_scope(scope): - SCOPE_NONE = 'none' - SCOPE_ALL = 'all' - SCOPE_QKV_PROJ = 'qkv_proj' - SCOPE_OUTPUT_PROJ = 'output_proj' - SCOPE_MLP = 'mlp' - SCOPE_EX_QKV_PROJ = 'exclude_qkv_proj' - SCOPE_EX_OUTPUT_PROJ = 'exclude_output_proj' - SCOPE_EX_MLP = 'exclude_mlp' + SCOPE_NONE = "none" + SCOPE_ALL = "all" + SCOPE_QKV_PROJ = "qkv_proj" + SCOPE_OUTPUT_PROJ = "output_proj" + SCOPE_MLP = "mlp" + SCOPE_EX_QKV_PROJ = "exclude_qkv_proj" + SCOPE_EX_OUTPUT_PROJ = "exclude_output_proj" + SCOPE_EX_MLP = "exclude_mlp" scope = SCOPE_NONE if scope is None else scope scope = scope.lower() assert scope in [ - SCOPE_NONE, SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_OUTPUT_PROJ, SCOPE_MLP, SCOPE_EX_QKV_PROJ, - SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP + SCOPE_NONE, + SCOPE_ALL, + SCOPE_QKV_PROJ, + SCOPE_OUTPUT_PROJ, + SCOPE_MLP, + SCOPE_EX_QKV_PROJ, + SCOPE_EX_OUTPUT_PROJ, + SCOPE_EX_MLP, ] lora_scope = LoRAScope() @@ -684,7 +718,7 @@ def _canonicalize_lora_scope(scope): return lora_scope -class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods +class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods r""" Multi-head Attention (MHA), including Query, Key, Value and Output projection. @@ -818,8 +852,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods head_dim: int num_attention_heads: int num_gqa_groups: Optional[int] = None - attention_dropout: float = 0. - dropout_rng_name: str = 'dropout' + attention_dropout: float = 0.0 + dropout_rng_name: str = "dropout" input_layernorm: bool = True layernorm_type: str = "layernorm" layernorm_epsilon: float = 1e-6 @@ -828,12 +862,12 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods kernel_init: Initializer = None use_bias: bool = False bias_init: Initializer = nn.initializers.zeros - attn_mask_type: str = 'causal' + attn_mask_type: str = "causal" attn_bias_type: Optional[str] = None enable_rotary_pos_emb: bool = False rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) - rotary_pos_emb_group_method: str = 'consecutive' - low_rank_adaptation_scope: str = 'none' + rotary_pos_emb_group_method: str = "consecutive" + low_rank_adaptation_scope: str = "none" low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None dtype: DType = jnp.float32 @@ -857,40 +891,50 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods self.num_attention_heads = self.num_heads warnings.warn( f"{__class__}.num_heads is deprecated. It will be removed recently. 
" - f"Please uses {__class__}.num_attention_heads as the new API.", DeprecationWarning) + f"Please uses {__class__}.num_attention_heads as the new API.", + DeprecationWarning, + ) if self.dropout_rate is not None: self.attention_dropout = self.dropout_rate warnings.warn( f"{__class__}.dropout_rate is deprecated. It will be removed recently. " - f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning) + f"Please use {__class__}.attention_dropout as the new API.", + DeprecationWarning, + ) if self.apply_residual_connection_post_layernorm is not None: warnings.warn( f"{__class__}.apply_residual_connection_post_layernorm is deprecated. " f"It will be removed recently, please use {__class__}.return_layernorm_output.", - DeprecationWarning) + DeprecationWarning, + ) if self.fuse_qkv is not None: warnings.warn( f"{__class__}.fuse_qkv is deprecated. It will be removed recently. " - f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning) + f"Please use {__class__}.fuse_qkv_params as the new API.", + DeprecationWarning, + ) assert self.output_layernorm is None, ( f"{__class__}.output_layernorm is deprecated. It will be removed recently. " - f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm.") + f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm." + ) if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal') + self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal") if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads super().__post_init__() @nn.compact - def __call__(self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - *, - decode: bool = False, - deterministic: bool = False) -> Array: + def __call__( + self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + *, + decode: bool = False, + deterministic: bool = False, + ) -> Array: """ MultiHeadAttention Layer: [Query, Key, Value projection] -> Dot Product Attention -> Output projection. 
@@ -963,12 +1007,14 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES return tuple(axes) - is_self_attn = (inputs_q is inputs_kv) - is_gqa = (self.num_attention_heads != self.num_gqa_groups) - is_qkvpack = (is_self_attn and not is_gqa) + is_self_attn = inputs_q is inputs_kv + is_gqa = self.num_attention_heads != self.num_gqa_groups + is_qkvpack = is_self_attn and not is_gqa - inputs_logical_axes_maybe_sp = (*generate_batch_seqlen_logical_axes( - self.enable_sequence_parallel), HIDDEN_AXES) + inputs_logical_axes_maybe_sp = ( + *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel), + HIDDEN_AXES, + ) inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES) inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp) @@ -998,9 +1044,10 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, - name='qkv', - dtype=self.dtype)(inputs_q) - qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj') + name="qkv", + dtype=self.dtype, + )(inputs_q) + qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj") qkv_layout = QKVLayout.BS3HD else: query, ln_out = LayerNormDenseGeneral( @@ -1025,26 +1072,29 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, - name='query')(inputs_q) + name="query", + )(inputs_q) if is_self_attn: assert ln_out is not None inputs_kv = ln_out - kv_proj = DenseGeneral(axis=-1, - features=(2, self.num_gqa_groups * self.head_dim), - transpose_batch_sequence=self.transpose_batch_sequence, - kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), - kernel_init=kv_init, - use_bias=self.use_bias, - bias_init=self.bias_init, - bias_axes=(W_JOINED_AXES, W_TP_AXES), - enable_low_rank_adaptation=lora_scope.qkv_proj, - low_rank_adaptation_dim=self.low_rank_adaptation_dim, - low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - name='kv', - dtype=self.dtype)(inputs_kv) - kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj') + kv_proj = DenseGeneral( + axis=-1, + features=(2, self.num_gqa_groups * self.head_dim), + transpose_batch_sequence=self.transpose_batch_sequence, + kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), + kernel_init=kv_init, + use_bias=self.use_bias, + bias_init=self.bias_init, + bias_axes=(W_JOINED_AXES, W_TP_AXES), + enable_low_rank_adaptation=lora_scope.qkv_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, + name="kv", + dtype=self.dtype, + )(inputs_kv) + kv_proj = checkpoint_name(kv_proj, "combined_kv_proj") qkv_layout = QKVLayout.BSHD_BS2HD else: kv_projection = functools.partial( @@ -1059,7 +1109,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods enable_low_rank_adaptation=lora_scope.qkv_proj, low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - dtype=self.dtype) + dtype=self.dtype, + ) query, ln_out = LayerNormDenseGeneral( enable_layernorm=self.input_layernorm, layernorm_type=self.layernorm_type, @@ -1082,17 +1133,18 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods 
kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, - name='query')(inputs_q) + name="query", + )(inputs_q) if is_self_attn: assert ln_out is not None inputs_kv = ln_out - key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv) - value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv) - query = checkpoint_name(query, 'query_proj') - key = checkpoint_name(key, 'key_proj') - value = checkpoint_name(value, 'value_proj') + key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv) + value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv) + query = checkpoint_name(query, "query_proj") + key = checkpoint_name(key, "key_proj") + value = checkpoint_name(value, "value_proj") qkv_layout = QKVLayout.BSHD_BSHD_BSHD if self.enable_rotary_pos_emb: @@ -1107,10 +1159,18 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) - query = rotary_pos_emb(query, self.rotary_pos_emb_windows, - self.transpose_batch_sequence, self.rotary_pos_emb_group_method) - key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence, - self.rotary_pos_emb_group_method) + query = rotary_pos_emb( + query, + self.rotary_pos_emb_windows, + self.transpose_batch_sequence, + self.rotary_pos_emb_group_method, + ) + key = rotary_pos_emb( + key, + self.rotary_pos_emb_windows, + self.transpose_batch_sequence, + self.rotary_pos_emb_group_method, + ) qkv_layout = QKVLayout.BSHD_BSHD_BSHD if qkv_layout == QKVLayout.BSHD_BSHD_BSHD: @@ -1120,13 +1180,15 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods if decode: assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD - is_initialized = self.has_variable('cache', 'cached_key') - - cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype) - cached_value = self.variable('cache', 'cached_value', jnp.zeros, value.shape, - value.dtype) - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.int32)) + is_initialized = self.has_variable("cache", "cached_key") + + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable( + "cache", "cached_value", jnp.zeros, value.shape, value.dtype + ) + cache_index = self.variable( + "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32) + ) if is_initialized: if self.transpose_batch_sequence: length, batch, num_attention_heads, head_dim = cached_key.value.shape @@ -1140,8 +1202,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods # Sanity shape check of cached key against input query. if expected_shape != query.shape: raise ValueError( - 'Autoregressive cache shape error, ' - f"expected query shape {expected_shape} instead got {query.shape}.") + "Autoregressive cache shape error, " + f"expected query shape {expected_shape} instead got {query.shape}." 
+ ) cur_index = cache_index.value one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype) @@ -1153,21 +1216,25 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods cache_index.value = cache_index.value + 1 mask = combine_masks( - mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))) + mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)) + ) if bias is not None: - dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, - in_axes=(None, 0, None, None)) - bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), - jnp.reshape(cur_index, (-1)), 1, -2) + dynamic_vector_slice_in_dim = vmap( + lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None) + ) + bias = dynamic_vector_slice_in_dim( + jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2 + ) LEADING_AXES = (BATCH_AXES, SEQLEN_AXES) if self.transpose_batch_sequence: LEADING_AXES = (SEQLEN_AXES, BATCH_AXES) if qkv_layout == QKVLayout.BS3HD: - qkv_proj = qkv_proj.reshape(*qkv_proj.shape[:2], 3, self.num_attention_heads, - self.head_dim) + qkv_proj = qkv_proj.reshape( + *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim + ) qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES) qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint) dpa_args = [qkv_proj, None, None] @@ -1191,43 +1258,46 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods dpa_args = [query, key, value] scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0 - x = DotProductAttention(head_dim=self.head_dim, - num_attention_heads=self.num_attention_heads, - num_gqa_groups=self.num_gqa_groups, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - attention_dropout=self.attention_dropout, - dtype=self.dtype, - dropout_rng_name=self.dropout_rng_name, - float32_logits=self.float32_logits, - qkv_layout=qkv_layout.name, - scale_factor=scale_factor, - transpose_batch_sequence=self.transpose_batch_sequence)( - *dpa_args, mask, bias, deterministic=deterministic) + x = DotProductAttention( + head_dim=self.head_dim, + num_attention_heads=self.num_attention_heads, + num_gqa_groups=self.num_gqa_groups, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + attention_dropout=self.attention_dropout, + dtype=self.dtype, + dropout_rng_name=self.dropout_rng_name, + float32_logits=self.float32_logits, + qkv_layout=qkv_layout.name, + scale_factor=scale_factor, + transpose_batch_sequence=self.transpose_batch_sequence, + )(*dpa_args, mask, bias, deterministic=deterministic) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES) x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint) - out = DenseGeneral(features=inputs_q.shape[-1], - transpose_batch_sequence=self.transpose_batch_sequence, - axis=-1, - kernel_init=self.kernel_init, - kernel_axes=(W_TP_AXES, W_FSDP_AXES), - use_bias=self.use_bias, - bias_init=self.bias_init, - bias_axes=(W_NO_SHARD_AXES,), - enable_low_rank_adaptation=lora_scope.output_proj, - low_rank_adaptation_dim=self.low_rank_adaptation_dim, - low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, - dtype=self.dtype, - name='out')(x) - out = checkpoint_name(out, 'out_proj') + out = DenseGeneral( + features=inputs_q.shape[-1], + transpose_batch_sequence=self.transpose_batch_sequence, + axis=-1, + 
kernel_init=self.kernel_init, + kernel_axes=(W_TP_AXES, W_FSDP_AXES), + use_bias=self.use_bias, + bias_init=self.bias_init, + bias_axes=(W_NO_SHARD_AXES,), + enable_low_rank_adaptation=lora_scope.output_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, + dtype=self.dtype, + name="out", + )(x) + out = checkpoint_name(out, "out_proj") return out, ln_out -class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-methods +class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-methods """ T5-style relative positional embeddings to the attention logits. @@ -1250,11 +1320,12 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-met dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. """ + num_buckets: int max_distance: int num_attention_heads: int embedding_init: Callable[..., Array] = nn.linear.default_embed_init - embedding_axes: Tuple[str, ...] = ('heads', 'relpos_buckets') + embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets") dtype: DType = jnp.float32 @nn.compact @@ -1296,26 +1367,30 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-met rpb_max_exact = rpb_num_buckets // 2 rpb_is_small = negative_rp < rpb_max_exact rpb_val_if_large = rpb_max_exact + ( - np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps) / - np.log(self.max_distance / rpb_max_exact) * - (rpb_num_buckets - rpb_max_exact)).astype(np.int32) + np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps) + / np.log(self.max_distance / rpb_max_exact) + * (rpb_num_buckets - rpb_max_exact) + ).astype(np.int32) rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1) rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large) # Compute relative attention bias relative_attention_bias = nn_partitioning.param_with_axes( - 'rel_embedding', - self.embedding_init, (self.num_attention_heads, self.num_buckets), + "rel_embedding", + self.embedding_init, + (self.num_attention_heads, self.num_buckets), jnp.float32, - axes=self.embedding_axes) + axes=self.embedding_axes, + ) relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0) rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype) - values = lax.dot_general(relative_attention_bias, rp_bucket_one_hot, - (((1,), (0,)), ((), ()))) + values = lax.dot_general( + relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ())) + ) return values[jnp.newaxis, ...] @@ -1330,11 +1405,12 @@ class TransformerLayerType(Enum): DECODER: Decoder type of TransformerLayer. """ + ENCODER = "encoder" DECODER = "decoder" -class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods +class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods r""" TransformerLayer is made up of a relative embedding, an attention block and a feedforward network (MLP). 
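# Editorial note (not part of the patch): the TransformerLayer docstring above
# describes the block as a relative embedding, an attention block, and an MLP.
# A minimal encoder-style sketch, with the import path, sizes, and shapes assumed
# for illustration (a TE-supported CUDA device is required to execute it):
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import TransformerLayer  # assumed export path

layer = TransformerLayer(
    hidden_size=512,
    mlp_hidden_size=2048,
    num_attention_heads=8,
    transpose_batch_sequence=False,  # inputs are [batch, seq, hidden]
)
x = jnp.zeros((2, 128, 512), dtype=jnp.float32)
params = layer.init({"params": jax.random.PRNGKey(0)}, x, deterministic=True)
y = layer.apply(params, x, deterministic=True)  # output keeps the shape of x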
@@ -1497,7 +1573,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mlp_hidden_size: int = 2048 num_attention_heads: int = 8 num_gqa_groups: Optional[int] = None - layernorm_type: str = 'layernorm' + layernorm_type: str = "layernorm" layernorm_epsilon: float = 1e-6 zero_centered_gamma: bool = False hidden_dropout: float = 0.1 @@ -1505,24 +1581,24 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods attention_dropout: float = 0.1 intermediate_dropout: float = 0.1 intermediate_dropout_dims: Sequence[int] = () - dropout_rng_name: str = 'dropout' + dropout_rng_name: str = "dropout" mha_kernel_init: Initializer = None mlp_kernel_init: Initializer = None - mlp_activations: Sequence[str] = ('relu',) + mlp_activations: Sequence[str] = ("relu",) use_bias: bool = False bias_init: Initializer = nn.initializers.zeros apply_residual_connection_post_layernorm: bool = False output_layernorm: bool = False float32_attention_logits: bool = False layer_type: TransformerLayerType = TransformerLayerType.ENCODER - self_attn_mask_type: str = 'causal' + self_attn_mask_type: str = "causal" self_attn_bias_type: Optional[str] = None enable_relative_embedding: bool = True relative_embedding: nn.Module = None enable_rotary_pos_emb: bool = False rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) - rotary_pos_emb_group_method: str = 'consecutive' - low_rank_adaptation_scope: str = 'none' + rotary_pos_emb_group_method: str = "consecutive" + low_rank_adaptation_scope: str = "none" low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None dtype: DType = jnp.float32 @@ -1535,23 +1611,26 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods def __post_init__(self): if self.mha_kernel_init is None: - self.mha_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal') + self.mha_kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal") if self.mlp_kernel_init is None: - self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', - 'truncated_normal') + self.mlp_kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal" + ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads super().__post_init__() @nn.compact - def __call__(self, - inputs: Array, - encoded: Array = None, - attention_mask: Array = None, - encoder_decoder_mask: Array = None, - deterministic: bool = False, - decode: bool = False, - max_decode_length: bool = None): + def __call__( + self, + inputs: Array, + encoded: Array = None, + attention_mask: Array = None, + encoder_decoder_mask: Array = None, + deterministic: bool = False, + decode: bool = False, + max_decode_length: bool = None, + ): """ Transformer Layer: attention block and a feedforward network (MLP) @@ -1585,17 +1664,18 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods outputs: jax.numpy.ndarray Output tensors. """ - assert self.layer_type in TransformerLayerType, \ - "layer_type should be one of TransformerLayerType" \ - f", but got {self.layer_type}." + assert ( + self.layer_type in TransformerLayerType + ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}." - assert self.hidden_size % self.num_attention_heads == 0, \ - "hidden_size should be multiples of num_attention_heads" \ - f", but got {self.hidden_size=} and {self.num_attention_heads=}." 
+ assert self.hidden_size % self.num_attention_heads == 0, ( + "hidden_size should be multiples of num_attention_heads" + f", but got {self.hidden_size=} and {self.num_attention_heads=}." + ) - assert self.layer_type == TransformerLayerType.DECODER or \ - (self.layer_type == TransformerLayerType.ENCODER and decode is False), \ - "decode should be False when layer_type == TransformerLayerType.ENCODER." + assert self.layer_type == TransformerLayerType.DECODER or ( + self.layer_type == TransformerLayerType.ENCODER and decode is False + ), "decode should be False when layer_type == TransformerLayerType.ENCODER." head_dim = self.hidden_size // self.num_attention_heads @@ -1605,8 +1685,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods def generate_batch_seqlen_logical_axes(is_shared_seq=None): axes = [None, None] - is_shared_seq = self.enable_sequence_parallel if is_shared_seq is None \ - else is_shared_seq + is_shared_seq = ( + self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq + ) axes[batch_dim] = BATCH_AXES axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES @@ -1615,13 +1696,14 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods attn_bias = None if self.enable_relative_embedding: if self.relative_embedding is None: - rel_emb = RelativePositionBiases(num_buckets=32, - max_distance=128, - num_attention_heads=self.num_attention_heads, - dtype=self.dtype, - embedding_init=nn.initializers.variance_scaling( - 1.0, 'fan_avg', 'uniform'), - name='relpos_bias') + rel_emb = RelativePositionBiases( + num_buckets=32, + max_distance=128, + num_attention_heads=self.num_attention_heads, + dtype=self.dtype, + embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"), + name="relpos_bias", + ) else: rel_emb = self.relative_embedding @@ -1639,12 +1721,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods # Make name be the exactly same as T5X, since names would affect # RNGKey during init and apply. Myabe no need in the feature. 
if self.layer_type == TransformerLayerType.ENCODER: - mha_name = 'attention' + mha_name = "attention" else: - mha_name = 'self_attention' + mha_name = "self_attention" inputs = with_sharding_constraint_by_logical_axes( - inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) + inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) + ) # [batch, length, emb_dim] -> [batch, length, emb_dim] residual = inputs @@ -1677,12 +1760,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods kernel_init=self.mha_kernel_init, use_bias=self.use_bias, bias_init=self.bias_init, - name=mha_name)(inputs, - inputs, - attention_mask, - attn_bias, - deterministic=deterministic, - decode=decode) + name=mha_name, + )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode) def hidden_dropout(x, deterministic): assert isinstance(self.hidden_dropout_dims, Sequence) @@ -1690,21 +1769,27 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods for dims in self.hidden_dropout_dims: assert -x_shape_len <= dims < x_shape_len - return nn.Dropout(rate=self.hidden_dropout, - broadcast_dims=self.hidden_dropout_dims, - rng_collection=self.dropout_rng_name)(x, deterministic=deterministic) + return nn.Dropout( + rate=self.hidden_dropout, + broadcast_dims=self.hidden_dropout_dims, + rng_collection=self.dropout_rng_name, + )(x, deterministic=deterministic) x = with_sharding_constraint_by_logical_axes( - x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) + x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) + ) residual = with_sharding_constraint_by_logical_axes( - residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) + residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) + ) x = hidden_dropout(x, deterministic) if self.drop_path > 0.0: drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim) - x = nn.Dropout(rate=self.drop_path, - broadcast_dims=drop_path_shape, - rng_collection=self.dropout_rng_name)(x, deterministic=deterministic) + x = nn.Dropout( + rate=self.drop_path, + broadcast_dims=drop_path_shape, + rng_collection=self.dropout_rng_name, + )(x, deterministic=deterministic) if self.apply_residual_connection_post_layernorm: assert ln_out is not None @@ -1714,11 +1799,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mlp_input = x if self.layer_type == TransformerLayerType.DECODER: - assert encoded is not None, \ - "encoded is required when layer_type == TransformerLayerType.DECODER." + assert ( + encoded is not None + ), "encoded is required when layer_type == TransformerLayerType.DECODER." x = with_sharding_constraint_by_logical_axes( - x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) + x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) + ) residual = x y, ln_out = MultiHeadAttention( @@ -1734,9 +1821,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods layernorm_epsilon=self.layernorm_epsilon, zero_centered_gamma=self.zero_centered_gamma, return_layernorm_output=self.apply_residual_connection_post_layernorm, - input_layernorm=True, # Must do LayerNorm before MHA. - attn_mask_type='padding', - attn_bias_type='no_bias', + input_layernorm=True, # Must do LayerNorm before MHA. 
+ attn_mask_type="padding", + attn_bias_type="no_bias", enable_rotary_pos_emb=self.enable_rotary_pos_emb, rotary_pos_emb_windows=self.rotary_pos_emb_windows, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, @@ -1750,15 +1837,15 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods kernel_init=self.mha_kernel_init, use_bias=self.use_bias, bias_init=self.bias_init, - name='encoder_decoder_attention')(x, - encoded, - encoder_decoder_mask, - deterministic=deterministic) + name="encoder_decoder_attention", + )(x, encoded, encoder_decoder_mask, deterministic=deterministic) y = with_sharding_constraint_by_logical_axes( - y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) + y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) + ) residual = with_sharding_constraint_by_logical_axes( - residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) + residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) + ) y = hidden_dropout(y, deterministic) @@ -1769,7 +1856,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mlp_input = y + residual mlp_input = with_sharding_constraint_by_logical_axes( - mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) + mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) + ) lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope) @@ -1802,7 +1890,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES), dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES), dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES), - name='mlp', + name="mlp", )(mlp_input, deterministic=deterministic) if self.apply_residual_connection_post_layernorm: @@ -1810,27 +1898,33 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods residual = ln_out z = with_sharding_constraint_by_logical_axes( - z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) + z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) + ) residual = with_sharding_constraint_by_logical_axes( - residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) + residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) + ) z = hidden_dropout(z, deterministic) if self.drop_path > 0.0: drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim) - z = nn.Dropout(rate=self.drop_path, - broadcast_dims=drop_path_shape)(z, deterministic=deterministic) + z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)( + z, deterministic=deterministic + ) z = z + residual if self.output_layernorm: z = with_sharding_constraint_by_logical_axes( - z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) - z = LayerNorm(layernorm_type=self.layernorm_type, - zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.layernorm_epsilon, - scale_axes=(W_NO_SHARD_AXES,), - bias_axes=(W_NO_SHARD_AXES,), - transpose_batch_sequence=self.transpose_batch_sequence, - dtype=self.dtype, - name="output_layernorm")(z) + z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) + ) + z = LayerNorm( + layernorm_type=self.layernorm_type, + zero_centered_gamma=self.zero_centered_gamma, + epsilon=self.layernorm_epsilon, + scale_axes=(W_NO_SHARD_AXES,), + bias_axes=(W_NO_SHARD_AXES,), + transpose_batch_sequence=self.transpose_batch_sequence, + dtype=self.dtype, + name="output_layernorm", + )(z) return z diff --git a/transformer_engine/jax/fp8.py 
b/transformer_engine/jax/fp8.py index 291531f90518c60f873f063a80b744cf43d33d50..4766203f69596ff8d05fbdba8087f9e5c858c5cd 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -32,9 +32,9 @@ Collection = Union[Dict, FrozenDict] def _check_fp8_support(gpu_id) -> Tuple[bool, str]: """Return if fp8 support is available""" gpu_arch = get_device_compute_capability(gpu_id) - if gpu_arch >= 90: # hopper and above + if gpu_arch >= 90: # hopper and above return True, "" - if gpu_arch < 89: # pre-ada + if gpu_arch < 89: # pre-ada return False, "Device compute capability 8.9 or higher required for FP8 execution." if get_cublasLt_version() < 120103: return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." @@ -135,8 +135,8 @@ class FP8MetaPackage: @staticmethod def update_fp8_scale( - amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray], - fp8_dtype_list: List[DType]) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]: + amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray], fp8_dtype_list: List[DType] + ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]: """ Get update scale and scale_inv list """ @@ -151,6 +151,7 @@ class FP8MetaPackage: class AmaxComputeAlgo(Enum): """AmaxComputeAlgo.""" + MAX = "max" MOST_RECENT = "most_recent" @@ -162,6 +163,7 @@ class FP8Helper: """ FP8 helper to manage the FP8 meta """ + INITIALIZED = False MARGIN: float = 0.0 FP8_FORMAT: Format = Format.HYBRID @@ -184,18 +186,19 @@ class FP8Helper: return FP8Helper.INITIALIZED @staticmethod - def initialize(margin: float = 0.0, - fp8_format: Format = Format.HYBRID, - amax_history_len: int = 1, - amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX) -> None: + def initialize( + margin: float = 0.0, + fp8_format: Format = Format.HYBRID, + amax_history_len: int = 1, + amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX, + ) -> None: """ Initialize the FP8 meta """ FP8Helper.INITIALIZED = True FP8Helper.MARGIN = margin FP8Helper.FP8_FORMAT = fp8_format - FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \ - _format2dtypes(FP8Helper.FP8_FORMAT) + FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = _format2dtypes(FP8Helper.FP8_FORMAT) FP8Helper.AMAX_HISTORY_LEN = amax_history_len FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo FP8Helper.FP8_2X_ACC_FPROP = False @@ -210,8 +213,7 @@ class FP8Helper: FP8Helper.INITIALIZED = False FP8Helper.MARGIN = 0.0 FP8Helper.FP8_FORMAT = Format.HYBRID - FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \ - _format2dtypes(FP8Helper.FP8_FORMAT) + FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = _format2dtypes(FP8Helper.FP8_FORMAT) FP8Helper.AMAX_HISTORY_LEN = 1024 FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX @@ -300,9 +302,11 @@ class FP8Helper: @contextmanager -def fp8_autocast(enabled: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, - mesh_resource: Optional[MeshResource] = None) -> None: +def fp8_autocast( + enabled: bool = False, + fp8_recipe: Optional[DelayedScaling] = None, + mesh_resource: Optional[MeshResource] = None, +) -> None: r""" Context manager for FP8 usage. 
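# Editorial note (not part of the patch): a usage sketch for the fp8_autocast
# context manager reformatted above. The keyword arguments match the signature in
# this hunk, and the recipe respects the asserts in the following hunk
# (amax_compute_algo in {"max", "most_recent"}, reduce_amax left enabled, no
# scaling_factor_compute_algo or override_linear_precision overrides). Import
# paths are assumptions, and enabled=True requires an FP8-capable GPU
# (compute capability 8.9 or higher, per _check_fp8_support above).
import transformer_engine.jax as te  # assumed top-level export of fp8_autocast
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(
    margin=0,
    fp8_format=Format.HYBRID,
    amax_history_len=16,
    amax_compute_algo="max",
)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    # Build, init, and apply TE/JAX modules here; FP8Helper.initialize() runs on
    # entry and FP8Helper.finalize() on exit, as shown in the surrounding hunks.
    pass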
@@ -344,13 +348,18 @@ def fp8_autocast(enabled: bool = False, fp8_recipe = DelayedScaling() assert fp8_recipe.amax_compute_algo in [ - "max", "most_recent" - ], ("DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX.") - assert fp8_recipe.scaling_factor_compute_algo is None, ( - "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX.") - assert fp8_recipe.override_linear_precision == (False, False, False), ( - "DelayedScaling override_linear_precision isn't supported by TE/JAX.") - assert fp8_recipe.reduce_amax, ("DelayedScaling reduce_amax should be enabled for TE/JAX.") + "max", + "most_recent", + ], "DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX." + assert ( + fp8_recipe.scaling_factor_compute_algo is None + ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX." + assert fp8_recipe.override_linear_precision == ( + False, + False, + False, + ), "DelayedScaling override_linear_precision isn't supported by TE/JAX." + assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX." if mesh_resource is None: mesh_resource = MeshResource() @@ -362,13 +371,15 @@ def fp8_autocast(enabled: bool = False, assert fp8_available, reason_for_no_fp8 amax_compute_algo = AmaxComputeAlgo.MOST_RECENT - if fp8_recipe.amax_compute_algo == 'max': + if fp8_recipe.amax_compute_algo == "max": amax_compute_algo = AmaxComputeAlgo.MAX - FP8Helper.initialize(margin=fp8_recipe.margin, - fp8_format=fp8_recipe.fp8_format, - amax_history_len=fp8_recipe.amax_history_len, - amax_compute_algo=amax_compute_algo) + FP8Helper.initialize( + margin=fp8_recipe.margin, + fp8_format=fp8_recipe.fp8_format, + amax_history_len=fp8_recipe.amax_history_len, + amax_compute_algo=amax_compute_algo, + ) yield finally: FP8Helper.finalize() @@ -410,9 +421,12 @@ def get_delayed_scaling(): delay_scaling : DelayedScaling an instance of DelayedScaling which is set via fp8_autocast. 
""" - amax_compute_algo = "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX \ - else "most_recent" - return DelayedScaling(margin=int(FP8Helper.MARGIN), - fp8_format=FP8Helper.FP8_FORMAT, - amax_history_len=FP8Helper.AMAX_HISTORY_LEN, - amax_compute_algo=amax_compute_algo) + amax_compute_algo = ( + "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent" + ) + return DelayedScaling( + margin=int(FP8Helper.MARGIN), + fp8_format=FP8Helper.FP8_FORMAT, + amax_history_len=FP8Helper.AMAX_HISTORY_LEN, + amax_compute_algo=amax_compute_algo, + ) diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index da50448fead0743bd40bfd6cb794fbfd792abf4a..e7364a13b6ba71c93890420ef6d27863914733a8 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -16,56 +16,55 @@ from .sharding import with_sharding_constraint_by_logical_axes def canonicalize_layernorm_type(x): - ''' + """ Canonicalize the layernorm type - ''' - canonicalized = x.lower().strip().replace('-', '').replace('_', '') - assert canonicalized in ['layernorm', 'rmsnorm'] + """ + canonicalized = x.lower().strip().replace("-", "").replace("_", "") + assert canonicalized in ["layernorm", "rmsnorm"] return canonicalized -def layernorm(inputs: jnp.ndarray, - gamma: jnp.ndarray, - beta: jnp.ndarray, - layernorm_type: str, - zero_centered_gamma: bool = False, - epsilon: float = 1e-6): +def layernorm( + inputs: jnp.ndarray, + gamma: jnp.ndarray, + beta: jnp.ndarray, + layernorm_type: str, + zero_centered_gamma: bool = False, + epsilon: float = 1e-6, +): """ LN/RMSNorm wrapper Only support layernorm_type in ['layernorm', 'rmsnorm'] """ - output = _layernorm(inputs, - gamma, - beta, - layernorm_type=layernorm_type, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + output = _layernorm( + inputs, + gamma, + beta, + layernorm_type=layernorm_type, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon, + ) return output @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) -def _layernorm(x, - gamma, - beta, - layernorm_type: str, - zero_centered_gamma: bool = False, - epsilon: float = 1e-6): +def _layernorm( + x, gamma, beta, layernorm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6 +): output, _ = _layernorm_fwd_rule(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon) return output -def _layernorm_fwd_rule(x, - gamma, - beta, - layernorm_type: str, - zero_centered_gamma: bool = False, - epsilon: float = 1e-6): +def _layernorm_fwd_rule( + x, gamma, beta, layernorm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6 +): layernorm_type = canonicalize_layernorm_type(layernorm_type) - if layernorm_type == 'layernorm': + if layernorm_type == "layernorm": output, mu, rsigma = tex.layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon) - elif layernorm_type == 'rmsnorm': - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" + elif layernorm_type == "rmsnorm": + assert ( + not zero_centered_gamma + ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'" output, rsigma = tex.rmsnorm_fwd(x, gamma, epsilon) mu = None else: @@ -75,17 +74,14 @@ def _layernorm_fwd_rule(x, def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz): x, mu, rsigma, gamma = ctx - if layernorm_type == 'layernorm': - dx, dgamma, dbeta = tex.layernorm_bwd(dz, - x, - mu, - rsigma, - gamma, - zero_centered_gamma=zero_centered_gamma, - 
epsilon=epsilon) - elif layernorm_type == 'rmsnorm': - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" + if layernorm_type == "layernorm": + dx, dgamma, dbeta = tex.layernorm_bwd( + dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + ) + elif layernorm_type == "rmsnorm": + assert ( + not zero_centered_gamma + ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'" dx, dgamma = tex.rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon) dbeta = None else: @@ -107,9 +103,11 @@ def layernorm_fp8_dot( zero_centered_gamma: bool = False, epsilon: float = 1e-6, layernorm_input_axes: Tuple[ - str, ...] = None, # The logic axes of sharding constraint to the layernorm input. - dot_input_axes: Tuple[str, - ...] = None # The logic axes of sharding constraint to the dot input. + str, ... + ] = None, # The logic axes of sharding constraint to the layernorm input. + dot_input_axes: Tuple[ + str, ... + ] = None, # The logic axes of sharding constraint to the dot input. ) -> jnp.ndarray: """ Layernorm + FP8 GEMM @@ -118,26 +116,41 @@ def layernorm_fp8_dot( scale_list = fp8_meta_pkg.scale_list fwd_dtype = FP8Helper.FWD_DTYPE bwd_dtype = FP8Helper.BWD_DTYPE - output = _layernorm_fp8_dot(x, kernel, gamma, beta, amax_list, scale_list, layernorm_type, - fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon, - layernorm_input_axes, dot_input_axes) + output = _layernorm_fp8_dot( + x, + kernel, + gamma, + beta, + amax_list, + scale_list, + layernorm_type, + fwd_dtype, + bwd_dtype, + zero_centered_gamma, + epsilon, + layernorm_input_axes, + dot_input_axes, + ) return output @partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12)) -def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, - amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray], - layernorm_type: str, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, - zero_centered_gamma: bool, epsilon: float, - layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...]): - output, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, amax_list, scale_list, - layernorm_type, fwd_dtype, bwd_dtype, - zero_centered_gamma, epsilon, layernorm_input_axes, - dot_input_axes) - return output - - -def _layernorm_fp8_dot_fwd_rule( +def _layernorm_fp8_dot( + x: jnp.ndarray, + kernel: jnp.ndarray, + gamma: jnp.ndarray, + beta: jnp.ndarray, + amax_list: List[jnp.ndarray], + scale_list: List[jnp.ndarray], + layernorm_type: str, + fwd_dtype: jnp.dtype, + bwd_dtype: jnp.dtype, + zero_centered_gamma: bool, + epsilon: float, + layernorm_input_axes: Tuple[str, ...], + dot_input_axes: Tuple[str, ...], +): + output, _ = _layernorm_fp8_dot_fwd_rule( x, kernel, gamma, @@ -146,24 +159,45 @@ def _layernorm_fp8_dot_fwd_rule( scale_list, layernorm_type, fwd_dtype, - bwd_dtype, # pylint: disable=unused-argument + bwd_dtype, zero_centered_gamma, epsilon, layernorm_input_axes, - dot_input_axes): + dot_input_axes, + ) + return output + + +def _layernorm_fp8_dot_fwd_rule( + x, + kernel, + gamma, + beta, + amax_list, + scale_list, + layernorm_type, + fwd_dtype, + bwd_dtype, # pylint: disable=unused-argument + zero_centered_gamma, + epsilon, + layernorm_input_axes, + dot_input_axes, +): x_contracting_dims = (len(x.shape) - 1,) k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] - maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \ - FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list, *scale_list) + maybe_fm32_to_fp32, 
maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( + *amax_list, *scale_list + ) amax_list = maybe_fm32_to_fp32(*amax_list) scale_list = maybe_fm32_to_fp32(*scale_list) fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype] - scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(amax_list, scale_list, - fp8_dtype_list) + scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale( + amax_list, scale_list, fp8_dtype_list + ) amax_list = FP8MetaPackage.update_amax_list(amax_list) x_amax = amax_list[FP8MetaPackage.INPUT_IDX][0:1] @@ -172,7 +206,7 @@ def _layernorm_fp8_dot_fwd_rule( x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) - if layernorm_type == 'layernorm': + if layernorm_type == "layernorm": ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8( x, gamma, @@ -182,17 +216,15 @@ def _layernorm_fp8_dot_fwd_rule( x_scale_inv, out_dtype=fwd_dtype, zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + epsilon=epsilon, + ) else: - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" - ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(x, - gamma, - x_amax, - x_scale, - x_scale_inv, - out_dtype=fwd_dtype, - epsilon=epsilon) + assert ( + not zero_centered_gamma + ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'" + ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8( + x, gamma, x_amax, x_scale, x_scale_inv, out_dtype=fwd_dtype, epsilon=epsilon + ) mu = None assert x.shape == ln_out.shape @@ -204,37 +236,74 @@ def _layernorm_fp8_dot_fwd_rule( # Kernel in (hidden_in, hidden_out...) # Note (Ming Huang): Use cast only to allow XLA handle tranpose for avoiding # unnecessary copy to break FP8 GEMM pattern matching. - casted_kernel, updated_kernel_amax = \ - tex.cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype) + casted_kernel, updated_kernel_amax = tex.cast_fp8( + kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype + ) ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_input_axes) # (batch..., hidden_in) x (hidden_in, hidden_out...) 
- output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype, - (x_contracting_dims, k_contracting_dims), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) - - ctx = (ln_out, casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax, - updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims, - k_contracting_dims, maybe_fp32_to_fm32) + output = fp8_dot_impl( + ln_out, + casted_kernel, + x_scale_inv, + kernel_scale_inv, + x.dtype, + (x_contracting_dims, k_contracting_dims), + get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP), + ) + + ctx = ( + ln_out, + casted_kernel, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + x.shape, + kernel.shape, + mu, + rsigma, + x, + gamma, + x_contracting_dims, + k_contracting_dims, + maybe_fp32_to_fm32, + ) return output, ctx def _layernorm_fp8_dot_bwd_rule( - layernorm_type, - fwd_dtype, # pylint: disable=unused-argument - bwd_dtype, - zero_centered_gamma, - epsilon, - layernorm_input_axes, - dot_input_axes, # pylint: disable=unused-argument - ctx, - grad): - ln_out_, casted_kernel, amax_list, scale_list, scale_inv_list, \ - updated_x_amax, updated_kernel_amax, \ - x_shape, kernel_shape, mu, rsigma, x, gamma, \ - x_contracting_dims, k_contracting_dims, maybe_fp32_to_fm32 = ctx + layernorm_type, + fwd_dtype, # pylint: disable=unused-argument + bwd_dtype, + zero_centered_gamma, + epsilon, + layernorm_input_axes, + dot_input_axes, # pylint: disable=unused-argument + ctx, + grad, +): + ( + ln_out_, + casted_kernel, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + x_shape, + kernel_shape, + mu, + rsigma, + x, + gamma, + x_contracting_dims, + k_contracting_dims, + maybe_fp32_to_fm32, + ) = ctx ln_out_t = tex.transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1) @@ -242,53 +311,70 @@ def _layernorm_fp8_dot_bwd_rule( grad_scale = scale_list[FP8MetaPackage.GRAD_IDX] grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX] - casted_grad, casted_grad_t, updated_grad_amax = \ - tex.cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype, - static_axis_boundary=-1, transpose_axis_boundary=min(x_contracting_dims)) + casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=min(x_contracting_dims), + ) xt_constracting_dim = tuple(range(len(x_contracting_dims), len(x_shape))) gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim)) x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] - wgrad = fp8_dot_impl(ln_out_t, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype, - (xt_constracting_dim, gt_constracting_dim), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) + wgrad = fp8_dot_impl( + ln_out_t, + casted_grad_t, + x_scale_inv, + grad_scale_inv, + grad.dtype, + (xt_constracting_dim, gt_constracting_dim), + get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD), + ) g_for_dgrad_constracting_dim = tuple( - range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim)) + range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim) + ) k_constracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape))) kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] - dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype, - (g_for_dgrad_constracting_dim, 
k_constracting_dim), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) + dgrad = fp8_dot_impl( + casted_grad, + casted_kernel, + grad_scale_inv, + kernel_scale_inv, + grad.dtype, + (g_for_dgrad_constracting_dim, k_constracting_dim), + get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD), + ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) - if layernorm_type == 'layernorm': - dx, dgamma, dbeta = tex.layernorm_bwd(dgrad, - x, - mu, - rsigma, - gamma, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + if layernorm_type == "layernorm": + dx, dgamma, dbeta = tex.layernorm_bwd( + dgrad, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + ) else: - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" + assert ( + not zero_centered_gamma + ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'" dx, dgamma = tex.rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon) dbeta = None - amax_list[FP8MetaPackage.INPUT_IDX] = \ + amax_list[FP8MetaPackage.INPUT_IDX] = ( amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0]) - amax_list[FP8MetaPackage.WEIGHT_IDX] = \ + ) + amax_list[FP8MetaPackage.WEIGHT_IDX] = ( amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0]) - amax_list[FP8MetaPackage.GRAD_IDX] = \ + ) + amax_list[FP8MetaPackage.GRAD_IDX] = ( amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0]) + ) amax_list = maybe_fp32_to_fm32(*amax_list) scale_list = maybe_fp32_to_fm32(*scale_list) - return dx, wgrad, \ - dgamma, dbeta, \ - amax_list, scale_list + return dx, wgrad, dgamma, dbeta, amax_list, scale_list _layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd_rule, _layernorm_fp8_dot_bwd_rule) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 0d77ed2cf54aafb8fd5e2dfe650432f04d62f1a0..0017acb80cfe77a77af2cb8accd27a632649f1d2 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -22,7 +22,7 @@ def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable] Activation Unit """ if len(activation_type) > 1: - assert x.shape[-2] == 2 # Linear + GeLU + assert x.shape[-2] == 2 # Linear + GeLU output = _activation_lu(x, activation_type) return output @@ -41,7 +41,7 @@ def _activation_lu_fwd_rule(x, activation_type): def _activation_lu_bwd_rule(activation_type, ctx, g): - x, = ctx + (x,) = ctx assert x.dtype == g.dtype dx = tex.dact_lu(g, x, activation_type) @@ -52,22 +52,24 @@ def _activation_lu_bwd_rule(activation_type, ctx, g): _activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule) -def fused_layernorm_fp8_mlp(x: jnp.ndarray, - gamma: jnp.ndarray, - beta: jnp.ndarray, - kernels: List[jnp.ndarray], - biases: List[jnp.ndarray], - fp8_meta_pkgs: List[FP8MetaPackage], - layernorm_type: str, - zero_centered_gamma: bool = False, - epsilon: float = 1e-6, - layernorm_input_axes: Tuple[str, ...] = None, - dot_1_input_axes: Tuple[str, ...] = None, - dot_2_input_axes: Tuple[str, ...] 
= None, - ffn1_ckpt_name: str = 'ffn1', - ffn2_ckpt_name: str = 'ffn2', - activation_type: Sequence[Union[str, Callable]] = ('gelu',), - use_bias: bool = True) -> jnp.ndarray: +def fused_layernorm_fp8_mlp( + x: jnp.ndarray, + gamma: jnp.ndarray, + beta: jnp.ndarray, + kernels: List[jnp.ndarray], + biases: List[jnp.ndarray], + fp8_meta_pkgs: List[FP8MetaPackage], + layernorm_type: str, + zero_centered_gamma: bool = False, + epsilon: float = 1e-6, + layernorm_input_axes: Tuple[str, ...] = None, + dot_1_input_axes: Tuple[str, ...] = None, + dot_2_input_axes: Tuple[str, ...] = None, + ffn1_ckpt_name: str = "ffn1", + ffn2_ckpt_name: str = "ffn2", + activation_type: Sequence[Union[str, Callable]] = ("gelu",), + use_bias: bool = True, +) -> jnp.ndarray: """ Layernorm + GEMM1 + bias + activation + GEMM2 + bias """ @@ -88,40 +90,67 @@ def fused_layernorm_fp8_mlp(x: jnp.ndarray, bwd_dtype = FP8Helper.BWD_DTYPE layernorm_type = canonicalize_layernorm_type(layernorm_type) - if layernorm_type == 'rmsnorm': + if layernorm_type == "rmsnorm": assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'" - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" - - output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, - amax_list_1, amax_list_2, scale_list_1, scale_list_2, - fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, - epsilon, layernorm_input_axes, dot_1_input_axes, - dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, - activation_type, use_bias) + assert ( + not zero_centered_gamma + ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'" + + output = _fused_layernorm_fp8_mlp( + x, + gamma, + beta, + kernel_1, + kernel_2, + bias_1, + bias_2, + amax_list_1, + amax_list_2, + scale_list_1, + scale_list_2, + fwd_dtype, + bwd_dtype, + layernorm_type, + zero_centered_gamma, + epsilon, + layernorm_input_axes, + dot_1_input_axes, + dot_2_input_axes, + ffn1_ckpt_name, + ffn2_ckpt_name, + activation_type, + use_bias, + ) return output @partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)) -def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, - kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray, - bias_2: jnp.ndarray, amax_list_1: List[jnp.ndarray], - amax_list_2: List[jnp.ndarray], scale_list_1: List[jnp.ndarray], - scale_list_2: List[jnp.ndarray], fwd_dtype: jnp.dtype, - bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool, - epsilon: float, layernorm_input_axes: Tuple[str, ...], - dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], - ffn1_ckpt_name: str, ffn2_ckpt_name: str, - activation_type: Sequence[Union[str, Callable]], use_bias: bool): +def _fused_layernorm_fp8_mlp( + x: jnp.ndarray, + gamma: jnp.ndarray, + beta: jnp.ndarray, + kernel_1: jnp.ndarray, + kernel_2: jnp.ndarray, + bias_1: jnp.ndarray, + bias_2: jnp.ndarray, + amax_list_1: List[jnp.ndarray], + amax_list_2: List[jnp.ndarray], + scale_list_1: List[jnp.ndarray], + scale_list_2: List[jnp.ndarray], + fwd_dtype: jnp.dtype, + bwd_dtype: jnp.dtype, + layernorm_type: str, + zero_centered_gamma: bool, + epsilon: float, + layernorm_input_axes: Tuple[str, ...], + dot_1_input_axes: Tuple[str, ...], + dot_2_input_axes: Tuple[str, ...], + ffn1_ckpt_name: str, + ffn2_ckpt_name: str, + activation_type: Sequence[Union[str, Callable]], + use_bias: bool, +): output, _ = _fused_layernorm_fp8_mlp_fwd_rule( - x, gamma, beta, 
kernel_1, kernel_2, bias_1, bias_2, amax_list_1, amax_list_2, scale_list_1, - scale_list_2, fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon, - layernorm_input_axes, dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, - activation_type, use_bias) - return output - - -def _fused_layernorm_fp8_mlp_fwd_rule( x, gamma, beta, @@ -134,7 +163,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule( scale_list_1, scale_list_2, fwd_dtype, - bwd_dtype, # pylint: disable=unused-argument + bwd_dtype, layernorm_type, zero_centered_gamma, epsilon, @@ -144,7 +173,36 @@ def _fused_layernorm_fp8_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - use_bias): + use_bias, + ) + return output + + +def _fused_layernorm_fp8_mlp_fwd_rule( + x, + gamma, + beta, + kernel_1, + kernel_2, + bias_1, + bias_2, + amax_list_1, + amax_list_2, + scale_list_1, + scale_list_2, + fwd_dtype, + bwd_dtype, # pylint: disable=unused-argument + layernorm_type, + zero_centered_gamma, + epsilon, + layernorm_input_axes, + dot_1_input_axes, + dot_2_input_axes, + ffn1_ckpt_name, + ffn2_ckpt_name, + activation_type, + use_bias, +): # x should be in shape of (batch..., hidden) # Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out) @@ -159,20 +217,22 @@ def _fused_layernorm_fp8_mlp_fwd_rule( assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0] assert kernel_1.shape[-1] == kernel_2.shape[0] - maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \ - FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list_1, *scale_list_1, - *amax_list_2, *scale_list_2) + maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( + *amax_list_1, *scale_list_1, *amax_list_2, *scale_list_2 + ) amax_list_1 = maybe_fm32_to_fp32(*amax_list_1) scale_list_1 = maybe_fm32_to_fp32(*scale_list_1) amax_list_2 = maybe_fm32_to_fp32(*amax_list_2) scale_list_2 = maybe_fm32_to_fp32(*scale_list_2) fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype] - scale_list_1, scale_inv_list_1 = FP8MetaPackage.update_fp8_scale(amax_list_1, scale_list_1, - fp8_dtype_list) + scale_list_1, scale_inv_list_1 = FP8MetaPackage.update_fp8_scale( + amax_list_1, scale_list_1, fp8_dtype_list + ) amax_list_1 = FP8MetaPackage.update_amax_list(amax_list_1) - scale_list_2, scale_inv_list_2 = FP8MetaPackage.update_fp8_scale(amax_list_2, scale_list_2, - fp8_dtype_list) + scale_list_2, scale_inv_list_2 = FP8MetaPackage.update_fp8_scale( + amax_list_2, scale_list_2, fp8_dtype_list + ) amax_list_2 = FP8MetaPackage.update_amax_list(amax_list_2) x_amax = amax_list_1[FP8MetaPackage.INPUT_IDX][0:1] @@ -181,7 +241,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule( x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) - if layernorm_type == 'layernorm': + if layernorm_type == "layernorm": ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8( x, gamma, @@ -191,17 +251,15 @@ def _fused_layernorm_fp8_mlp_fwd_rule( x_scale_inv, out_dtype=fwd_dtype, zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + epsilon=epsilon, + ) else: - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" - ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(x, - gamma, - x_amax, - x_scale, - x_scale_inv, - out_dtype=fwd_dtype, - epsilon=epsilon) + assert ( + not zero_centered_gamma + ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'" + ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8( + x, gamma, x_amax, x_scale, x_scale_inv, out_dtype=fwd_dtype, epsilon=epsilon + ) mu = 
None assert x.shape == ln_out.shape @@ -212,15 +270,22 @@ def _fused_layernorm_fp8_mlp_fwd_rule( # Note (Ming Huang): Use cast only to allow XLA handle tranpose for avoiding # unnecessary copy to break FP8 GEMM pattern matching. - casted_kernel_1, updated_kernel_1_amax = \ - tex.cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype) + casted_kernel_1, updated_kernel_1_amax = tex.cast_fp8( + kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype + ) ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes) # (batch..., hidden_in) x (hidden_in, hidden_out) - dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype, - (x_contracting_dims, (0,)), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) + dot_1_output = fp8_dot_impl( + ln_out, + casted_kernel_1, + x_scale_inv, + kernel_1_scale_inv, + x.dtype, + (x_contracting_dims, (0,)), + get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP), + ) if use_bias: bias_1_shape = bias_1.shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape @@ -234,12 +299,18 @@ def _fused_layernorm_fp8_mlp_fwd_rule( activation_lu_out_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX] # (batch..., hidden_in) -> (batch..., hidden) - casted_activation_lu_out, updated_activation_lu_amax = \ - tex.act_lu_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale, - activation_lu_out_scale_inv, fwd_dtype, activation_type) + casted_activation_lu_out, updated_activation_lu_amax = tex.act_lu_fp8( + dot_1_output, + activation_lu_out_amax, + activation_lu_out_scale, + activation_lu_out_scale_inv, + fwd_dtype, + activation_type, + ) casted_activation_lu_out = with_sharding_constraint_by_logical_axes( - casted_activation_lu_out, dot_2_input_axes) + casted_activation_lu_out, dot_2_input_axes + ) kernel_2_scale = scale_list_2[FP8MetaPackage.WEIGHT_IDX] kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX] @@ -248,10 +319,15 @@ def _fused_layernorm_fp8_mlp_fwd_rule( casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale) # (batch..., hidden_in) x (hidden_out, hidden_in) - dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2, - activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype, - (x_contracting_dims, (0,)), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) + dot_2_output = fp8_dot_impl( + casted_activation_lu_out, + casted_kernel_2, + activation_lu_out_scale_inv, + kernel_2_scale_inv, + x.dtype, + (x_contracting_dims, (0,)), + get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP), + ) if use_bias: bias_2_shape = bias_2.shape @@ -262,35 +338,78 @@ def _fused_layernorm_fp8_mlp_fwd_rule( dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) - ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1, - casted_kernel_2, amax_list_1, amax_list_2, scale_list_1, scale_list_2, scale_inv_list_1, - scale_inv_list_2, updated_x_amax, updated_activation_lu_amax, updated_kernel_1_amax, - updated_kernel_2_amax, x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, - maybe_fp32_to_fm32) + ctx = ( + x, + ln_out, + mu, + rsigma, + gamma, + dot_1_output, + casted_activation_lu_out, + casted_kernel_1, + casted_kernel_2, + amax_list_1, + amax_list_2, + scale_list_1, + scale_list_2, + scale_inv_list_1, + scale_inv_list_2, + updated_x_amax, + updated_activation_lu_amax, + updated_kernel_1_amax, + updated_kernel_2_amax, + x_contracting_dims, + 
xt_batch_dims, + bias_1_shape, + bias_2_shape, + maybe_fp32_to_fm32, + ) return dot_2_output, ctx def _fused_layernorm_fp8_mlp_bwd_rule( - fwd_dtype, # pylint: disable=unused-argument - bwd_dtype, - layernorm_type, - zero_centered_gamma, - epsilon, - layernorm_input_axes, - dot_1_input_axes, - dot_2_input_axes, - ffn1_ckpt_name, # pylint: disable=unused-argument - ffn2_ckpt_name, # pylint: disable=unused-argument - activation_type, - use_bias, - ctx, - grad): - x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \ - casted_kernel_1, casted_kernel_2, amax_list_1, amax_list_2, scale_list_1, scale_list_2, \ - scale_inv_list_1, scale_inv_list_2, updated_x_amax, \ - updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ - x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx + fwd_dtype, # pylint: disable=unused-argument + bwd_dtype, + layernorm_type, + zero_centered_gamma, + epsilon, + layernorm_input_axes, + dot_1_input_axes, + dot_2_input_axes, + ffn1_ckpt_name, # pylint: disable=unused-argument + ffn2_ckpt_name, # pylint: disable=unused-argument + activation_type, + use_bias, + ctx, + grad, +): + ( + x, + ln_out, + mu, + rsigma, + gamma, + dot_1_output, + casted_activation_lu_out, + casted_kernel_1, + casted_kernel_2, + amax_list_1, + amax_list_2, + scale_list_1, + scale_list_2, + scale_inv_list_1, + scale_inv_list_2, + updated_x_amax, + updated_activation_lu_amax, + updated_kernel_1_amax, + updated_kernel_2_amax, + x_contracting_dims, + xt_batch_dims, + bias_1_shape, + bias_2_shape, + maybe_fp32_to_fm32, + ) = ctx grad_amax = amax_list_2[FP8MetaPackage.GRAD_IDX][0:1] grad_scale = scale_list_2[FP8MetaPackage.GRAD_IDX] @@ -299,35 +418,55 @@ def _fused_layernorm_fp8_mlp_bwd_rule( # Since the sharding of outputs should be the same as dot_1's input grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) if use_bias: - casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \ - tex.dbias_cast_transpose(grad, grad_amax, grad_scale, - grad_scale_inv, bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1) + casted_grad, casted_grad_t, dbias_2, updated_grad_amax = tex.dbias_cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) dbias_2 = jnp.reshape(dbias_2, bias_2_shape) else: - casted_grad, casted_grad_t, updated_grad_amax = \ - tex.cast_transpose(grad, grad_amax, grad_scale, - grad_scale_inv, bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1) + casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) dbias_2 = None - casted_activation_lu_out_t = tex.transpose(casted_activation_lu_out, - static_axis_boundary=-1, - transpose_axis_boundary=-1) + casted_activation_lu_out_t = tex.transpose( + casted_activation_lu_out, static_axis_boundary=-1, transpose_axis_boundary=-1 + ) # (hidden, batch...,) x (hidden, batch...) 
gemm2_x_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX] - wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv, - grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) + wgrad_2 = fp8_dot_impl( + casted_activation_lu_out_t, + casted_grad_t, + gemm2_x_scale_inv, + grad_scale_inv, + grad.dtype, + (xt_batch_dims, xt_batch_dims), + get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD), + ) # (batch..., hidden_out) x (hidden_in, hidden_out) kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX] - dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv, - grad.dtype, (x_contracting_dims, (1,)), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) + dgrad_2 = fp8_dot_impl( + casted_grad, + casted_kernel_2, + grad_scale_inv, + kernel_2_scale_inv, + grad.dtype, + (x_contracting_dims, (1,)), + get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD), + ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -335,56 +474,64 @@ def _fused_layernorm_fp8_mlp_bwd_rule( dactivation_lu_scale = scale_list_1[FP8MetaPackage.GRAD_IDX] dactivation_lu_scale_inv = scale_inv_list_1[FP8MetaPackage.GRAD_IDX] - if len(activation_type) > 1: # if gated + if len(activation_type) > 1: # if gated if use_bias: dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type) - casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \ - tex.dbias_cast_transpose( - dactivation_lu, - dactivation_lu_amax, - dactivation_lu_scale, - dactivation_lu_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-2) + casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = ( + tex.dbias_cast_transpose( + dactivation_lu, + dactivation_lu_amax, + dactivation_lu_scale, + dactivation_lu_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-2, + ) + ) dbias_1 = jnp.reshape(dbias_1, bias_1_shape) else: - casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \ - tex.dgated_act_lu_cast_transpose( - dgrad_2, - dot_1_output, - dactivation_lu_amax, - dactivation_lu_scale, - dactivation_lu_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - activation_type=activation_type) + casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = ( + tex.dgated_act_lu_cast_transpose( + dgrad_2, + dot_1_output, + dactivation_lu_amax, + dactivation_lu_scale, + dactivation_lu_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + activation_type=activation_type, + ) + ) dbias_1 = None else: if use_bias: - casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax=\ - tex.dact_lu_dbias_cast_transpose( - dgrad_2, - dot_1_output, - dactivation_lu_amax, - dactivation_lu_scale, - dactivation_lu_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-2, - activation_type=activation_type) + casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = ( + tex.dact_lu_dbias_cast_transpose( + dgrad_2, + dot_1_output, + dactivation_lu_amax, + dactivation_lu_scale, + dactivation_lu_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-2, + activation_type=activation_type, + ) + ) dbias_1 = jnp.reshape(dbias_1, bias_1_shape) else: dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type) - casted_dactivation_lu, casted_dactivation_lu_t, 
updated_dactivation_lu_amax = \ - tex.cast_transpose( - dactivation_lu, - dactivation_lu_amax, - dactivation_lu_scale, - dactivation_lu_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-2) + casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = ( + tex.cast_transpose( + dactivation_lu, + dactivation_lu_amax, + dactivation_lu_scale, + dactivation_lu_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-2, + ) + ) dbias_1 = None ln_out_t = tex.transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1) @@ -392,54 +539,83 @@ def _fused_layernorm_fp8_mlp_bwd_rule( # (hidden, batch...) x (hidden, batch...) gemm1_x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX] xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims) - wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv, - dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) - - x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims), - (1, 2)) + wgrad_1 = fp8_dot_impl( + ln_out_t, + casted_dactivation_lu_t, + gemm1_x_scale_inv, + dactivation_lu_scale_inv, + grad.dtype, + (xt_batch_dims, xt_batch_dims_2), + get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD), + ) + + x_contracting_dims = ( + (min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims), + (1, 2), + ) kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX] - dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv, - kernel_1_scale_inv, grad.dtype, x_contracting_dims, - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) + dgrad_1 = fp8_dot_impl( + casted_dactivation_lu, + casted_kernel_1, + dactivation_lu_scale_inv, + kernel_1_scale_inv, + grad.dtype, + x_contracting_dims, + get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD), + ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes) - if layernorm_type == 'layernorm': - dx, dgamma, dbeta = tex.layernorm_bwd(dgrad_1, - x, - mu, - rsigma, - gamma, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + if layernorm_type == "layernorm": + dx, dgamma, dbeta = tex.layernorm_bwd( + dgrad_1, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + ) else: - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" + assert ( + not zero_centered_gamma + ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'" dx, dgamma = tex.rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon) dbeta = None - amax_list_1[FP8MetaPackage.INPUT_IDX] = \ + amax_list_1[FP8MetaPackage.INPUT_IDX] = ( amax_list_1[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0]) - amax_list_1[FP8MetaPackage.WEIGHT_IDX] = \ + ) + amax_list_1[FP8MetaPackage.WEIGHT_IDX] = ( amax_list_1[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_1_amax[0]) - amax_list_1[FP8MetaPackage.GRAD_IDX] = \ + ) + amax_list_1[FP8MetaPackage.GRAD_IDX] = ( amax_list_1[FP8MetaPackage.GRAD_IDX].at[0].set(updated_dactivation_lu_amax[0]) - amax_list_2[FP8MetaPackage.INPUT_IDX] = \ + ) + amax_list_2[FP8MetaPackage.INPUT_IDX] = ( amax_list_2[FP8MetaPackage.INPUT_IDX].at[0].set(updated_activation_lu_amax[0]) - amax_list_2[FP8MetaPackage.WEIGHT_IDX] = \ + ) + amax_list_2[FP8MetaPackage.WEIGHT_IDX] = ( amax_list_2[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_2_amax) - 
amax_list_2[FP8MetaPackage.GRAD_IDX] = \ + ) + amax_list_2[FP8MetaPackage.GRAD_IDX] = ( amax_list_2[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0]) + ) amax_list_1 = maybe_fp32_to_fm32(*amax_list_1) scale_list_1 = maybe_fp32_to_fm32(*scale_list_1) amax_list_2 = maybe_fp32_to_fm32(*amax_list_2) scale_list_2 = maybe_fp32_to_fm32(*scale_list_2) - return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \ - amax_list_1, amax_list_2, scale_list_1, scale_list_2 + return ( + dx, + dgamma, + dbeta, + wgrad_1, + wgrad_2, + dbias_1, + dbias_2, + amax_list_1, + amax_list_2, + scale_list_1, + scale_list_2, + ) -_fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule, - _fused_layernorm_fp8_mlp_bwd_rule) +_fused_layernorm_fp8_mlp.defvjp( + _fused_layernorm_fp8_mlp_fwd_rule, _fused_layernorm_fp8_mlp_bwd_rule +) diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py index 26c45cbed6f226958a9c4a10f39f45e50ed28313..b82c0915e4a33d27fe097a4b7dd5da93dcbbab2b 100644 --- a/transformer_engine/jax/praxis/module.py +++ b/transformer_engine/jax/praxis/module.py @@ -49,17 +49,19 @@ class TransformerEngineBaseLayer(BaseLayer): FP8Helper.FP8_COLLECTION_NAME: [ WeightHParamsCollection.SKIP_LP_REGULARIZATION, WeightHParamsCollection.OVERWRITE_WITH_GRADIENT, - WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION + WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION, ] } - flax_module_p = pax_fiddle.Config(flax_adapter.FlaxModuleAdapter, - module_factory_method=flax_module_cls, - logical_axes_rules=self.logical_axes_rules, - var_collection_map=fp8_collection_map, - ici_mesh_shape=self.ici_mesh_shape, - dcn_mesh_shape=self.dcn_mesh_shape, - mesh_axis_names=self.mesh_axis_names) + flax_module_p = pax_fiddle.Config( + flax_adapter.FlaxModuleAdapter, + module_factory_method=flax_module_cls, + logical_axes_rules=self.logical_axes_rules, + var_collection_map=fp8_collection_map, + ici_mesh_shape=self.ici_mesh_shape, + dcn_mesh_shape=self.dcn_mesh_shape, + mesh_axis_names=self.mesh_axis_names, + ) self.create_child(name, flax_module_p.clone()) @@ -68,7 +70,7 @@ class LayerNorm(TransformerEngineBaseLayer): """LayerNorm""" epsilon: float = 1e-6 - layernorm_type: str = 'layernorm' + layernorm_type: str = "layernorm" zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] 
= () @@ -80,17 +82,18 @@ class LayerNorm(TransformerEngineBaseLayer): """setup""" super().setup() - ln_cls = partial(flax_LayerNorm, - epsilon=self.epsilon, - layernorm_type=self.layernorm_type, - zero_centered_gamma=self.zero_centered_gamma, - scale_init=_generate_ln_scale_init(self.scale_init), - scale_axes=self.scale_axes, - bias_init=TransformerEngineBaseLayer.generate_params_init( - "ln_bias", self.bias_init), - bias_axes=self.bias_axes, - dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence) + ln_cls = partial( + flax_LayerNorm, + epsilon=self.epsilon, + layernorm_type=self.layernorm_type, + zero_centered_gamma=self.zero_centered_gamma, + scale_init=_generate_ln_scale_init(self.scale_init), + scale_axes=self.scale_axes, + bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", self.bias_init), + bias_axes=self.bias_axes, + dtype=self.dtype, + transpose_batch_sequence=self.transpose_batch_sequence, + ) self.create_layer("layer_norm", ln_cls) @@ -109,9 +112,9 @@ class FusedSoftmax(TransformerEngineBaseLayer): """setup""" super().setup() - fused_softmax_cls = partial(Softmax, - scale_factor=self.scale_factor, - softmax_type=self.softmax_type) + fused_softmax_cls = partial( + Softmax, scale_factor=self.scale_factor, softmax_type=self.softmax_type + ) self.create_layer("fused_softmax", fused_softmax_cls) @@ -151,7 +154,8 @@ class Linear(TransformerEngineBaseLayer): low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, axis=self.axis, dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence) + transpose_batch_sequence=self.transpose_batch_sequence, + ) self.create_layer("linear", dense_general_cls) @@ -165,7 +169,7 @@ class LayerNormLinear(TransformerEngineBaseLayer): out_features: int = 512 enable_layernorm: bool = True - layernorm_type: str = 'layernorm' + layernorm_type: str = "layernorm" epsilon: float = 1e-6 zero_centered_gamma: bool = False scale_init: WeightInit = None @@ -198,7 +202,8 @@ class LayerNormLinear(TransformerEngineBaseLayer): scale_init=_generate_ln_scale_init(self.scale_init), scale_axes=self.scale_axes, ln_bias_init=TransformerEngineBaseLayer.generate_params_init( - "ln_bias", self.ln_bias_init), + "ln_bias", self.ln_bias_init + ), ln_bias_axes=self.ln_bias_axes, kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init), kernel_axes=self.kernel_axes, @@ -212,7 +217,8 @@ class LayerNormLinear(TransformerEngineBaseLayer): axis=self.axis, dtype=self.dtype, transpose_batch_sequence=self.transpose_batch_sequence, - depth_scaling=self.depth_scaling) + depth_scaling=self.depth_scaling, + ) self.create_layer("ln_linear", ln_dense_general_cls) @@ -226,7 +232,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): intermediate_dim: int = 2048 enable_layernorm: bool = True - layernorm_type: str = 'layernorm' + layernorm_type: str = "layernorm" epsilon: float = 1e-6 zero_centered_gamma: bool = False scale_init: WeightInit = None @@ -243,7 +249,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None return_layernorm_output: bool = True - activations: Sequence[Union[str, Callable]] = ('relu',) + activations: Sequence[Union[str, Callable]] = ("relu",) intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () axis: Union[Iterable[int], int] = -1 @@ -263,7 +269,8 @@ class LayerNormMLP(TransformerEngineBaseLayer): scale_init=_generate_ln_scale_init(self.scale_init), scale_axes=self.scale_axes, 
ln_bias_init=TransformerEngineBaseLayer.generate_params_init( - "ln_bias", self.ln_bias_init), + "ln_bias", self.ln_bias_init + ), ln_bias_axes=self.ln_bias_axes, kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init), kernel_axes_1=self.kernel_axes_1, @@ -281,7 +288,8 @@ class LayerNormMLP(TransformerEngineBaseLayer): intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims, axis=self.axis, dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence) + transpose_batch_sequence=self.transpose_batch_sequence, + ) self.create_layer("ln_mlp", ln_mlp_cls) diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py index ede101ad2a7440a4fcd0452fe180dd2c73462d5d..2651144eeee35386959ea3f06f51a2fd89c5e55e 100644 --- a/transformer_engine/jax/praxis/transformer.py +++ b/transformer_engine/jax/praxis/transformer.py @@ -35,7 +35,7 @@ class RelativePositionBiases(TransformerEngineBaseLayer): """generate_embedding_init""" embedding_init = init if embedding_init is None: - rb_stddev = (num_attention_heads * num_buckets)**-0.5 + rb_stddev = (num_attention_heads * num_buckets) ** -0.5 embedding_init = WeightInit.Gaussian(rb_stddev) return embedding_init @@ -44,16 +44,20 @@ class RelativePositionBiases(TransformerEngineBaseLayer): super().setup() embedding_init = RelativePositionBiases.generate_embedding_init( - self.embedding_init, self.num_attention_heads, self.num_buckets) + self.embedding_init, self.num_attention_heads, self.num_buckets + ) - rpb_cls = partial(flax_RelativePositionBiases, - num_buckets=self.num_buckets, - max_distance=self.max_distance, - num_attention_heads=self.num_attention_heads, - embedding_init=TransformerEngineBaseLayer.generate_params_init( - "rel_embedding", embedding_init), - embedding_axes=self.embedding_axes, - dtype=self.dtype) + rpb_cls = partial( + flax_RelativePositionBiases, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + num_attention_heads=self.num_attention_heads, + embedding_init=TransformerEngineBaseLayer.generate_params_init( + "rel_embedding", embedding_init + ), + embedding_axes=self.embedding_axes, + dtype=self.dtype, + ) self.create_layer("relative_position_bias", rpb_cls) @@ -68,12 +72,12 @@ class DotProductAttention(TransformerEngineBaseLayer): head_dim: int = 0 num_attention_heads: int = 0 num_gqa_groups: Optional[int] = None - attention_dropout: float = 0. 
- attn_mask_type: AttnMaskType = 'causal' + attention_dropout: float = 0.0 + attn_mask_type: AttnMaskType = "causal" attn_bias_type: AttnBiasType = None - dropout_rng_name: str = 'dropout' + dropout_rng_name: str = "dropout" float32_logits: bool = False - qkv_layout: str = 'bshd_bshd_bshd' + qkv_layout: str = "bshd_bshd_bshd" scale_factor: Optional[float] = None transpose_batch_sequence: bool = True @@ -81,40 +85,41 @@ class DotProductAttention(TransformerEngineBaseLayer): """setup""" super().setup() - assert self.head_dim > 0, f'{self.head_dim=}' - assert self.num_attention_heads > 0, f'{self.num_attention_heads=}' - - dpa_cls = partial(flax_DotProductAttention, - head_dim=self.head_dim, - num_attention_heads=self.num_attention_heads, - num_gqa_groups=self.num_gqa_groups, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - attention_dropout=self.attention_dropout, - dtype=self.dtype, - dropout_rng_name=self.dropout_rng_name, - float32_logits=self.float32_logits, - qkv_layout=self.qkv_layout, - scale_factor=self.scale_factor, - transpose_batch_sequence=self.transpose_batch_sequence) + assert self.head_dim > 0, f"{self.head_dim=}" + assert self.num_attention_heads > 0, f"{self.num_attention_heads=}" + + dpa_cls = partial( + flax_DotProductAttention, + head_dim=self.head_dim, + num_attention_heads=self.num_attention_heads, + num_gqa_groups=self.num_gqa_groups, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + attention_dropout=self.attention_dropout, + dtype=self.dtype, + dropout_rng_name=self.dropout_rng_name, + float32_logits=self.float32_logits, + qkv_layout=self.qkv_layout, + scale_factor=self.scale_factor, + transpose_batch_sequence=self.transpose_batch_sequence, + ) self.create_layer("dot_product_attention", dpa_cls) - def __call__(self, - query: JTensor, - key: JTensor, - value: JTensor, - mask: Optional[JTensor] = None, - bias: Optional[JTensor] = None, - *, - deterministic: bool = False) -> JTensor: + def __call__( + self, + query: JTensor, + key: JTensor, + value: JTensor, + mask: Optional[JTensor] = None, + bias: Optional[JTensor] = None, + *, + deterministic: bool = False, + ) -> JTensor: """__call__""" - return self.dot_product_attention(query, - key, - value, - mask, - bias, - deterministic=deterministic) + return self.dot_product_attention( + query, key, value, mask, bias, deterministic=deterministic + ) class MultiHeadAttention(TransformerEngineBaseLayer): @@ -123,8 +128,8 @@ class MultiHeadAttention(TransformerEngineBaseLayer): head_dim: int = 0 num_attention_heads: int = 0 num_gqa_groups: Optional[int] = None - attention_dropout: float = 0. 
- dropout_rng_name: str = 'dropout' + attention_dropout: float = 0.0 + dropout_rng_name: str = "dropout" input_layernorm: bool = True layernorm_type: str = "layernorm" layernorm_epsilon: float = 1e-6 @@ -132,12 +137,12 @@ class MultiHeadAttention(TransformerEngineBaseLayer): return_layernorm_output: bool = False use_bias: bool = False bias_init: WeightInit = WeightInit.Constant(0.0) - attn_mask_type: str = 'causal' + attn_mask_type: str = "causal" attn_bias_type: Optional[str] = None enable_rotary_pos_emb: bool = False rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) - rotary_pos_emb_group_method: str = 'consecutive' - low_rank_adaptation_scope: str = 'none' + rotary_pos_emb_group_method: str = "consecutive" + low_rank_adaptation_scope: str = "none" low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None fuse_qkv_params: bool = True @@ -160,24 +165,32 @@ class MultiHeadAttention(TransformerEngineBaseLayer): self.num_attention_heads = self.num_heads warnings.warn( f"{__class__}.num_heads is deprecated. It will be removed recently. " - f"Please uses {__class__}.num_attention_heads as the new API.", DeprecationWarning) + f"Please uses {__class__}.num_attention_heads as the new API.", + DeprecationWarning, + ) if self.dropout_rate is not None: self.attention_dropout = self.dropout_rate warnings.warn( f"{__class__}.dropout_rate is deprecated. It will be removed recently. " - f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning) + f"Please use {__class__}.attention_dropout as the new API.", + DeprecationWarning, + ) if self.apply_residual_connection_post_layernorm is not None: warnings.warn( f"{__class__}.apply_residual_connection_post_layernorm is deprecated. " f"It will be removed recently, please use {__class__}.return_layernorm_output.", - DeprecationWarning) + DeprecationWarning, + ) if self.fuse_qkv is not None: warnings.warn( f"{__class__}.fuse_qkv is deprecated. It will be removed recently. " - f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning) + f"Please use {__class__}.fuse_qkv_params as the new API.", + DeprecationWarning, + ) assert self.output_layernorm is None, ( f"{__class__}.output_layernorm is deprecated. It will be removed recently. " - f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm.") + f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm." 
+ ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_heads @@ -187,8 +200,8 @@ class MultiHeadAttention(TransformerEngineBaseLayer): """setup""" super().setup() - assert self.head_dim > 0, f'{self.head_dim=}' - assert self.num_attention_heads > 0, f'{self.num_attention_heads=}' + assert self.head_dim > 0, f"{self.head_dim=}" + assert self.num_attention_heads > 0, f"{self.num_attention_heads=}" mha_cls = partial( flax_MultiHeadAttention, @@ -219,25 +232,25 @@ class MultiHeadAttention(TransformerEngineBaseLayer): enable_sequence_parallel=self.enable_sequence_parallel, scale_attn_logits=self.scale_attn_logits, scaled_query_init=self.scaled_query_init, - float32_logits=self.float32_logits) + float32_logits=self.float32_logits, + ) self.create_layer("multi_head_attn", mha_cls) - def __call__(self, - inputs_q: JTensor, - inputs_kv: JTensor, - mask: Optional[JTensor] = None, - bias: Optional[JTensor] = None, - *, - decode: bool = False, - deterministic: bool = False) -> JTensor: + def __call__( + self, + inputs_q: JTensor, + inputs_kv: JTensor, + mask: Optional[JTensor] = None, + bias: Optional[JTensor] = None, + *, + decode: bool = False, + deterministic: bool = False, + ) -> JTensor: """__call__""" - return self.multi_head_attn(inputs_q, - inputs_kv, - mask, - bias, - decode=decode, - deterministic=deterministic) + return self.multi_head_attn( + inputs_q, inputs_kv, mask, bias, decode=decode, deterministic=deterministic + ) class TransformerLayer(TransformerEngineBaseLayer): @@ -247,7 +260,7 @@ class TransformerLayer(TransformerEngineBaseLayer): mlp_hidden_size: int = 2048 num_attention_heads: int = 8 num_gqa_groups: Optional[int] = None - layernorm_type: str = 'layernorm' + layernorm_type: str = "layernorm" layernorm_epsilon: float = 1e-6 zero_centered_gamma: bool = False hidden_dropout: float = 0.1 @@ -255,20 +268,20 @@ class TransformerLayer(TransformerEngineBaseLayer): attention_dropout: float = 0.1 intermediate_dropout: float = 0.1 intermediate_dropout_dims: Sequence[int] = () - dropout_rng_name: str = 'dropout' - mlp_activations: Sequence[str] = ('relu',) + dropout_rng_name: str = "dropout" + mlp_activations: Sequence[str] = ("relu",) use_bias: bool = False bias_init: WeightInit = WeightInit.Constant(0.0) apply_residual_connection_post_layernorm: bool = False output_layernorm: bool = False float32_attention_logits: bool = False layer_type: TransformerLayerType = TransformerLayerType.ENCODER - self_attn_mask_type: str = 'causal' + self_attn_mask_type: str = "causal" self_attn_bias_type: Optional[str] = None enable_rotary_pos_emb: bool = False rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) - rotary_pos_emb_group_method: str = 'consecutive' - low_rank_adaptation_scope: str = 'none' + rotary_pos_emb_group_method: str = "consecutive" + low_rank_adaptation_scope: str = "none" low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None enable_relative_embedding: bool = True @@ -291,23 +304,27 @@ class TransformerLayer(TransformerEngineBaseLayer): relative_embedding_flax_module = None if self.enable_relative_embedding and self.relative_embedding is not None: - assert self.relative_embedding.num_attention_heads == \ - self.num_attention_heads, \ - "TransformerLayer.relative_embedding.num_attention_heads shoule be" \ + assert self.relative_embedding.num_attention_heads == self.num_attention_heads, ( + "TransformerLayer.relative_embedding.num_attention_heads shoule be" "the same as TransformerLayer.num_attention_heads." 
+ ) embedding_init = RelativePositionBiases.generate_embedding_init( - self.relative_embedding.embedding_init, self.relative_embedding.num_attention_heads, - self.relative_embedding.num_buckets) + self.relative_embedding.embedding_init, + self.relative_embedding.num_attention_heads, + self.relative_embedding.num_buckets, + ) relative_embedding_flax_module = flax_RelativePositionBiases( num_buckets=self.relative_embedding.num_buckets, max_distance=self.relative_embedding.max_distance, num_attention_heads=self.relative_embedding.num_attention_heads, embedding_init=TransformerEngineBaseLayer.generate_params_init( - "rel_embedding", embedding_init), + "rel_embedding", embedding_init + ), embedding_axes=self.relative_embedding.embedding_axes, - dtype=self.relative_embedding.dtype) + dtype=self.relative_embedding.dtype, + ) transformerlayer_cls = partial( flax_TransformerLayer, @@ -326,9 +343,11 @@ class TransformerLayer(TransformerEngineBaseLayer): intermediate_dropout_dims=self.intermediate_dropout_dims, dropout_rng_name=self.dropout_rng_name, mha_kernel_init=TransformerEngineBaseLayer.generate_params_init( - "mha_kernel", self.params_init), + "mha_kernel", self.params_init + ), mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init( - "mlp_kernel", self.params_init), + "mlp_kernel", self.params_init + ), mlp_activations=self.mlp_activations, use_bias=self.use_bias, bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), @@ -351,18 +370,28 @@ class TransformerLayer(TransformerEngineBaseLayer): transpose_batch_sequence=self.transpose_batch_sequence, enable_sequence_parallel=self.enable_sequence_parallel, scale_attn_logits=self.scale_attn_logits, - scaled_query_init=self.scaled_query_init) + scaled_query_init=self.scaled_query_init, + ) self.create_layer("transformerlayer", transformerlayer_cls) - def __call__(self, - inputs: JTensor, - encoded: JTensor = None, - attention_mask: JTensor = None, - encoder_decoder_mask: JTensor = None, - deterministic: bool = False, - decode: bool = False, - max_decode_length: bool = None) -> JTensor: + def __call__( + self, + inputs: JTensor, + encoded: JTensor = None, + attention_mask: JTensor = None, + encoder_decoder_mask: JTensor = None, + deterministic: bool = False, + decode: bool = False, + max_decode_length: bool = None, + ) -> JTensor: """__call__""" - return self.transformerlayer(inputs, encoded, attention_mask, encoder_decoder_mask, - deterministic, decode, max_decode_length) + return self.transformerlayer( + inputs, + encoded, + attention_mask, + encoder_decoder_mask, + deterministic, + decode, + max_decode_length, + ) diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 2d5786465e3ca186a9dc00ea1ecaf4668f3b8b9d..47cbfd958e0284e803bf174a848d82d6cedf68bb 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -30,7 +30,7 @@ from build_tools.utils import package_files, copy_common_headers, install_and_im from build_tools.te_version import te_version from build_tools.jax import setup_jax_extension -install_and_import('pybind11') +install_and_import("pybind11") from pybind11.setup_helpers import build_ext as BuildExtension CMakeBuildExtension = get_build_ext(BuildExtension) @@ -39,12 +39,12 @@ CMakeBuildExtension = get_build_ext(BuildExtension) if __name__ == "__main__": # Extensions common_headers_dir = "common_headers" - copy_common_headers( - current_file_path.parent, - str(current_file_path / common_headers_dir)) + 
copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir)) ext_modules = [ setup_jax_extension( - "csrc", current_file_path / "csrc", current_file_path / common_headers_dir)] + "csrc", current_file_path / "csrc", current_file_path / common_headers_dir + ) + ] # Configure package setuptools.setup( @@ -57,9 +57,11 @@ if __name__ == "__main__": install_requires=["jax", "flax>=0.7.1"], tests_require=["numpy", "praxis"], include_package_data=True, - package_data={"csrc": package_files("csrc"), - common_headers_dir: package_files(common_headers_dir), - "build_tools": package_files("build_tools")}, + package_data={ + "csrc": package_files("csrc"), + common_headers_dir: package_files(common_headers_dir), + "build_tools": package_files("build_tools"), + }, ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 205ac56bd5815397557d57ea9ccc45cf310a3c6d..c0b60fe61eb91ad147db88f22658c731ad676d61 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -17,23 +17,22 @@ from jax.sharding import PartitionSpec _PXLA_THREAD_RESOURCES = pxla.thread_resources # Axis Names -BATCH_AXES = 'nvte_batch' -SEQLEN_AXES = 'nvte_seqlen' -SEQLEN_TP_AXES = 'nvte_seqlen_tp' -HEAD_AXES = 'nvte_head' -HIDDEN_AXES = 'nvte_hidden' -HIDDEN_TP_AXES = 'nvte_hidden_tp' -JOINED_AXES = 'nvte_joined' -W_NO_SHARD_AXES = 'nvte_w_no_shard' -W_FSDP_AXES = 'nvte_w_fsdp' -W_TP_AXES = 'nvte_w_tp' -W_JOINED_AXES = 'nvte_w_joined' +BATCH_AXES = "nvte_batch" +SEQLEN_AXES = "nvte_seqlen" +SEQLEN_TP_AXES = "nvte_seqlen_tp" +HEAD_AXES = "nvte_head" +HIDDEN_AXES = "nvte_hidden" +HIDDEN_TP_AXES = "nvte_hidden_tp" +JOINED_AXES = "nvte_joined" +W_NO_SHARD_AXES = "nvte_w_no_shard" +W_FSDP_AXES = "nvte_w_fsdp" +W_TP_AXES = "nvte_w_tp" +W_JOINED_AXES = "nvte_w_joined" def _get_mesh_info(resource: str): mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh - assert resource in mesh.axis_names, \ - f"{resource} is not in the axis_names of Mesh {mesh}." + assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}." return mesh.shape[resource], resource @@ -45,8 +44,11 @@ def get_sharding_map_logic_axis_to_mesh_axis(): IS_FSDP_OUTER = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", False))) - batch_resources = [gsr.fsdp_resource, gsr.dp_resource] if IS_FSDP_OUTER \ - else [gsr.dp_resource, gsr.fsdp_resource] + batch_resources = ( + [gsr.fsdp_resource, gsr.dp_resource] + if IS_FSDP_OUTER + else [gsr.dp_resource, gsr.fsdp_resource] + ) batch_dim_rule = [] for resource in batch_resources: @@ -168,6 +170,7 @@ class MeshResource: The axis name in Mesh used to split model layers. along. If it is None, then pipeline parallelism is disabled. """ + dp_resource: str = None tp_resource: str = None fsdp_resource: str = None @@ -240,6 +243,7 @@ class MajorShardingType(Enum): DPTP: Data and Standard tensor parallel training. """ + SINGLE = 0 DP = 1 TP = 2 @@ -267,6 +271,7 @@ class ShardingType(Enum): DP_TP_ROW: Sharding along data and row-split tensor parallelism. 
""" + SINGLE = (MajorShardingType.SINGLE, "single") DP = (MajorShardingType.DP, "dp") TP_COL = (MajorShardingType.TP, "tp_col") diff --git a/transformer_engine/jax/softmax.py b/transformer_engine/jax/softmax.py index 3ebe41bc8bc27355d060bc89caf91424049417ea..0a997776efceae1c0f7a3614693aa2b5696e14a1 100644 --- a/transformer_engine/jax/softmax.py +++ b/transformer_engine/jax/softmax.py @@ -14,15 +14,18 @@ from . import cpp_extensions as tex class SoftmaxType(Enum): """SoftmaxType.""" + SCALED = "scaled" SCALED_MASKED = "scaled_masked" SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked" -def softmax(logits: jnp.ndarray, - mask: Optional[jnp.ndarray] = None, - scale_factor: Optional[float] = 1.0, - softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED): +def softmax( + logits: jnp.ndarray, + mask: Optional[jnp.ndarray] = None, + scale_factor: Optional[float] = 1.0, + softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED, +): """ Softmax wrapper """ @@ -50,7 +53,7 @@ def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type): def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz): - softmax_output, = ctx + (softmax_output,) = ctx if softmax_type is SoftmaxType.SCALED_MASKED: dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, scale_factor) diff --git a/transformer_engine/paddle/__init__.py b/transformer_engine/paddle/__init__.py index 04f5f77e36fef59c8e6dd8e1708445433b31dd38..62fa1fe6263d15c832b6f490a34eb597e9f0cbd3 100644 --- a/transformer_engine/paddle/__init__.py +++ b/transformer_engine/paddle/__init__.py @@ -6,6 +6,7 @@ # pylint: disable=wrong-import-position,wrong-import-order + def _load_library(): """Load shared library with Transformer Engine C extensions""" from transformer_engine import transformer_engine_paddle # pylint: disable=unused-import diff --git a/transformer_engine/paddle/constants.py b/transformer_engine/paddle/constants.py index 5ddcf7becc430095684a6ce30066471354d0b713..69d3859b8fc7f3c4eade6a4e0fa8f2099b4ec550 100644 --- a/transformer_engine/paddle/constants.py +++ b/transformer_engine/paddle/constants.py @@ -13,6 +13,7 @@ from transformer_engine import transformer_engine_paddle as tex class FP8FwdTensors(Enum): """Used as named indices on the `scale`, `scale_inv`, and `amax` tensors in the `FP8TensorMeta` class.""" + GEMM1_INPUT = 0 GEMM1_WEIGHT = 1 GEMM1_OUTPUT = 2 @@ -24,6 +25,7 @@ class FP8FwdTensors(Enum): class FP8BwdTensors(Enum): """Used as named indices on the `scale`, `scale_inv`, and `amax` tensors in the `FP8TensorMeta` class.""" + GRAD_OUTPUT1 = 0 GRAD_INPUT1 = 1 GRAD_OUTPUT2 = 2 @@ -51,7 +53,7 @@ GemmParallelModes = ("row", "column", None) dist_group_type = paddle.distributed.collective.Group -RecomputeFunctionNames = ('unpack', 'backward') +RecomputeFunctionNames = ("unpack", "backward") AttnBiasType = { "no_bias": tex.NVTE_Bias_Type.NVTE_NO_BIAS, diff --git a/transformer_engine/paddle/cpp_extensions.py b/transformer_engine/paddle/cpp_extensions.py index f2a8c5f394a0c40aa86ee0c099953b981a35b705..4a3763b1dcad84480afb54794dc6c9a4be393be2 100644 --- a/transformer_engine/paddle/cpp_extensions.py +++ b/transformer_engine/paddle/cpp_extensions.py @@ -66,8 +66,9 @@ def gemm( bias = bias if use_bias else None - assert A.dtype == dtype and B.dtype == dtype, \ - f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}' + assert ( + A.dtype == dtype and B.dtype == dtype + ), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}" input_dtype = TE_DType[dtype] output_dtype = TE_DType[out.dtype] if 
use_bias: @@ -82,13 +83,13 @@ def gemm( None, grad_bias if grad else bias, out, - None, # out_scale - None, # out_amax + None, # out_scale + None, # out_amax gelu_input, workspace, - 0, # A_index - 0, # B_index - 0, # D_index + 0, # A_index + 0, # B_index + 0, # D_index int(input_dtype), int(input_dtype), int(output_dtype), @@ -98,8 +99,8 @@ def gemm( grad, workspace.shape[0], accumulate, - False, # use_split_accumulator - 0, # math_sm_count + False, # use_split_accumulator + 0, # math_sm_count ) return out, grad_bias, gelu_input @@ -168,7 +169,7 @@ def fp8_gemm( out, None if out_index is None else fp8_meta_tensor.scale, None if out_index is None else fp8_meta_tensor.amax_history, - gelu_input, # this is pre_gelu_out + gelu_input, # this is pre_gelu_out workspace, A_fp8_tensor.value, B_fp8_tensor.value, @@ -177,13 +178,13 @@ def fp8_gemm( int(B_dtype), int(out_dtype), int(bias_dtype), - True, # transa - False, # transb - False, # grad + True, # transa + False, # transb + False, # grad workspace.shape[0], accumulate, use_split_accumulator, - 0, # math_sm_count + 0, # math_sm_count ) return out, gelu_input @@ -270,8 +271,10 @@ def cast_transpose( dtype=paddle.uint8, ) else: - assert transpose_out.shape == [inp.shape[1], inp.shape[0] - ], "Transposed output shape does not match input shape." + assert transpose_out.shape == [ + inp.shape[1], + inp.shape[0], + ], "Transposed output shape does not match input shape." assert transpose_out.dtype == paddle.uint8, "Output should be of uint8 dtype." tex.te_cast_transpose( @@ -348,7 +351,9 @@ def swiglu( ) -def swiglu_pd(inp: paddle.Tensor,) -> paddle.Tensor: +def swiglu_pd( + inp: paddle.Tensor, +) -> paddle.Tensor: """Native SWIGLU""" gate_out, up_out = paddle.chunk(inp, chunks=2, axis=-1) out = F.silu(gate_out) * up_out @@ -423,11 +428,19 @@ def layernorm_fwd_fp8( zero_centered_gamma: bool = False, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """LayerNorm with FP8 output""" - out, mu, rsigma, _, _ = tex.te_layernorm_fwd_fp8(inp, weight, bias, fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, eps, - fp8_tensor.value, int(otype), sm_margin, - zero_centered_gamma) + out, mu, rsigma, _, _ = tex.te_layernorm_fwd_fp8( + inp, + weight, + bias, + fp8_meta_tensor.scale, + fp8_meta_tensor.amax_history, + fp8_meta_tensor.scale_inv, + eps, + fp8_tensor.value, + int(otype), + sm_margin, + zero_centered_gamma, + ) return out, mu, rsigma @@ -480,10 +493,18 @@ def rmsnorm_fwd_fp8( zero_centered_gamma: bool = False, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """RMSNorm with FP8 output""" - out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(inp, weight, fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, eps, fp8_tensor.value, - int(otype), sm_margin, zero_centered_gamma) + out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8( + inp, + weight, + fp8_meta_tensor.scale, + fp8_meta_tensor.amax_history, + fp8_meta_tensor.scale_inv, + eps, + fp8_tensor.value, + int(otype), + sm_margin, + zero_centered_gamma, + ) return out, rsigma @@ -533,8 +554,10 @@ def fused_attn_fwd_qkvpacked( ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Fused Attention FWD for packed QKV input""" - assert (qkv_dtype in (tex.DType.kBFloat16, - tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention." + assert qkv_dtype in ( + tex.DType.kBFloat16, + tex.DType.kFloat16, + ), "Only support bf16/fp16 for fused attention." 
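As a rough orientation for the layouts involved: the *_qkvpacked wrappers above take q/k/v folded into a single tensor, whereas the unpacked fused_attn_fwd/bwd paths further down require the "bshd_bshd_bshd" layout, i.e. three separate [batch, seqlen, heads, head_dim] tensors. A minimal sketch, assuming the usual "bs3hd" packing convention; the sizes are hypothetical:

    import numpy as np

    b, s, h, d = 2, 512, 16, 64                        # hypothetical sizes
    q = np.zeros((b, s, h, d), dtype=np.float16)       # "bshd" tensor for q
    k = np.zeros_like(q)                               # "bshd" tensor for k
    v = np.zeros_like(q)                               # "bshd" tensor for v
    # unpacked "bshd_bshd_bshd": q, k, v stay separate; packed "bs3hd" stacks them on a new axis
    qkv = np.stack([q, k, v], axis=2)                  # shape [b, s, 3, h, d]
    assert qkv.shape == (b, s, 3, h, d)
    assert qkv.shape[-2] == h and qkv.shape[-1] == d   # matches the h/d lookups in the wrappers
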
b = cu_seqlens.shape[0] - 1 total_seqs = qkv.shape[0] * qkv.shape[1] @@ -546,17 +569,23 @@ def fused_attn_fwd_qkvpacked( if bias_type != "no_bias": assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert (Bias.shape == [1, h, max_seqlen, max_seqlen - ]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." - assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv." - - assert (fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert Bias.shape == [ + 1, + h, + max_seqlen, + max_seqlen, + ], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." + assert Bias.dtype == qkv.dtype, "bias tensor must be in the same dtype as qkv." + + assert ( + fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = (max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA - - 1) // BACKEND_F16m512_THREADS_PER_CTA + rng_elts_per_thread = ( + max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA - 1 + ) // BACKEND_F16m512_THREADS_PER_CTA # BF16/FP16 fused attention API from fmha_v2 if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: @@ -571,15 +600,18 @@ def fused_attn_fwd_qkvpacked( if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype) elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype='float32') + softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype="float32") else: raise ValueError("Unsupported fused attention backend.") else: softmax_aux = None - rng_state = paddle.empty(shape=[ - 2, - ], dtype=paddle.int64) + rng_state = paddle.empty( + shape=[ + 2, + ], + dtype=paddle.int64, + ) # execute kernel tex.te_fused_attn_fwd_qkvpacked( @@ -625,8 +657,10 @@ def fused_attn_bwd_qkvpacked( ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Fused Attention BWD for packed QKV input""" - assert (qkv_dtype in (tex.DType.kBFloat16, - tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention." + assert qkv_dtype in ( + tex.DType.kBFloat16, + tex.DType.kFloat16, + ), "Only support bf16/fp16 for fused attention." b = cu_seqlens.shape[0] - 1 total_seqs = qkv.shape[0] * qkv.shape[1] @@ -636,8 +670,9 @@ def fused_attn_bwd_qkvpacked( if attn_scale is None: attn_scale = 1.0 / math.sqrt(d) - assert (fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert ( + fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." if set_zero: dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype) @@ -694,10 +729,13 @@ def fused_attn_fwd_kvpacked( ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Fused Attention FWD for packed KV input""" - assert (qkv_dtype in (tex.DType.kBFloat16, - tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention." - assert (cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" + assert qkv_dtype in ( + tex.DType.kBFloat16, + tex.DType.kFloat16, + ), "Only support bf16/fp16 for fused attention." 
+ assert ( + cu_seqlens_q.shape == cu_seqlens_kv.shape + ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" b = cu_seqlens_q.shape[0] - 1 total_seqs_q = q.shape[0] * q.shape[1] @@ -710,17 +748,23 @@ def fused_attn_fwd_kvpacked( if bias_type != "no_bias": assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert (Bias.shape == [1, h, max_seqlen_q, max_seqlen_kv - ]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." - assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv." - - assert (fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert Bias.shape == [ + 1, + h, + max_seqlen_q, + max_seqlen_kv, + ], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." + assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as q and kv." + + assert ( + fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - - 1) // BACKEND_F16m512_THREADS_PER_CTA + rng_elts_per_thread = ( + max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1 + ) // BACKEND_F16m512_THREADS_PER_CTA # BF16/FP16 fused attention API from fmha_v2 if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: @@ -735,15 +779,18 @@ def fused_attn_fwd_kvpacked( if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32') + softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32") else: raise ValueError("Unsupported fused attention backend.") else: softmax_aux = None - rng_state = paddle.empty(shape=[ - 2, - ], dtype=paddle.int64) + rng_state = paddle.empty( + shape=[ + 2, + ], + dtype=paddle.int64, + ) # execute kernel tex.te_fused_attn_fwd_kvpacked( @@ -797,10 +844,13 @@ def fused_attn_bwd_kvpacked( ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Fused Attention BWD for packed KV input""" - assert (qkv_dtype in (tex.DType.kBFloat16, - tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention." - assert (cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" + assert qkv_dtype in ( + tex.DType.kBFloat16, + tex.DType.kFloat16, + ), "Only support bf16/fp16 for fused attention." + assert ( + cu_seqlens_q.shape == cu_seqlens_kv.shape + ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" b = cu_seqlens_q.shape[0] - 1 total_seqs_q = q.shape[0] * q.shape[1] @@ -811,8 +861,9 @@ def fused_attn_bwd_kvpacked( if attn_scale is None: attn_scale = 1.0 / math.sqrt(d) - assert (fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert ( + fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." 
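The repeated b = cu_seqlens_q.shape[0] - 1 computations in these wrappers rely on cu_seqlens being cumulative sequence lengths (a leading zero followed by running totals), so the batch size is one less than its length. A small self-contained illustration with made-up lengths:

    import numpy as np

    seqlens = [3, 5, 2]                          # hypothetical per-sample sequence lengths
    cu_seqlens = np.cumsum([0] + seqlens)        # -> array([0, 3, 8, 10])
    b = cu_seqlens.shape[0] - 1                  # -> 3, the batch size used above
    total_tokens = int(cu_seqlens[-1])           # -> 10
    assert b == len(seqlens) and total_tokens == sum(seqlens)
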
if set_zero: dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) @@ -875,12 +926,16 @@ def fused_attn_fwd( ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Fused Attention FWD for unpacked QKV input""" - assert (qkv_dtype in (tex.DType.kBFloat16, - tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention." - assert (cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - assert (qkv_layout == "bshd_bshd_bshd" - ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now." + assert qkv_dtype in ( + tex.DType.kBFloat16, + tex.DType.kFloat16, + ), "Only support bf16/fp16 for fused attention." + assert ( + cu_seqlens_q.shape == cu_seqlens_kv.shape + ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" + assert ( + qkv_layout == "bshd_bshd_bshd" + ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now." b = cu_seqlens_q.shape[0] - 1 h = q.shape[-2] @@ -891,18 +946,23 @@ def fused_attn_fwd( if bias_type != "no_bias": assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert (Bias.shape == [ - 1, h, max_seqlen_q, max_seqlen_kv - ]), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape." - assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as qkv." - - assert (fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert Bias.shape == [ + 1, + h, + max_seqlen_q, + max_seqlen_kv, + ], "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape." + assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as qkv." + + assert ( + fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - - 1) // BACKEND_F16m512_THREADS_PER_CTA + rng_elts_per_thread = ( + max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1 + ) // BACKEND_F16m512_THREADS_PER_CTA # BF16/FP16 fused attention API from fmha_v2 if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: @@ -917,15 +977,18 @@ def fused_attn_fwd( if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32') + softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32") else: raise ValueError("Unsupported fused attention backend.") else: softmax_aux = None - rng_state = paddle.empty(shape=[ - 2, - ], dtype=paddle.int64) + rng_state = paddle.empty( + shape=[ + 2, + ], + dtype=paddle.int64, + ) # execute kernel tex.te_fused_attn_fwd( @@ -978,12 +1041,16 @@ def fused_attn_bwd( ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Fused Attention BWD for packed KV input""" - assert (qkv_dtype in (tex.DType.kBFloat16, - tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention." - assert (cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - assert (qkv_layout == "bshd_bshd_bshd" - ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now." 
+ assert qkv_dtype in ( + tex.DType.kBFloat16, + tex.DType.kFloat16, + ), "Only support bf16/fp16 for fused attention." + assert ( + cu_seqlens_q.shape == cu_seqlens_kv.shape + ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" + assert ( + qkv_layout == "bshd_bshd_bshd" + ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now." b = cu_seqlens_q.shape[0] - 1 h = q.shape[-2] @@ -992,8 +1059,9 @@ def fused_attn_bwd( if attn_scale is None: attn_scale = 1.0 / math.sqrt(d) - assert (fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert ( + fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." if set_zero: dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) @@ -1041,7 +1109,7 @@ def scaled_softmax_forward( inp: paddle.Tensor, scale_factor: float, ) -> paddle.Tensor: - """ scaled softmax forward""" + """scaled softmax forward""" return tex.te_scaled_softmax_forward(inp, scale_factor) @@ -1050,7 +1118,7 @@ def scaled_softmax_backward( softmax_results: paddle.Tensor, scale_factor: float, ) -> paddle.Tensor: - """ scaled softmax backward""" + """scaled softmax backward""" tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor) return out_grad @@ -1060,7 +1128,7 @@ def scaled_masked_softmax_forward( mask: paddle.Tensor, scale_factor: float, ) -> paddle.Tensor: - """ scaled masked softmax forward""" + """scaled masked softmax forward""" return tex.te_scaled_masked_softmax_forward(inp, mask, scale_factor) @@ -1070,7 +1138,7 @@ def scaled_masked_softmax_backward( softmax_results: paddle.Tensor, scale_factor: float, ) -> paddle.Tensor: - """ scaled masked softmax backward""" + """scaled masked softmax backward""" tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor) return out_grad @@ -1079,7 +1147,7 @@ def scaled_upper_triang_masked_softmax_forward( inp: paddle.Tensor, scale_factor: float, ) -> paddle.Tensor: - """ scaled upper triang masked softmax forward""" + """scaled upper triang masked softmax forward""" return tex.te_scaled_upper_triang_masked_softmax_forward(inp, scale_factor) @@ -1088,6 +1156,6 @@ def scaled_upper_triang_masked_softmax_backward( softmax_results: paddle.Tensor, scale_factor: float, ) -> paddle.Tensor: - """ scaled upper triang masked softmax backward""" + """scaled upper triang masked softmax backward""" tex.te_scaled_upper_triang_masked_softmax_backward(out_grad, softmax_results, scale_factor) return out_grad diff --git a/transformer_engine/paddle/csrc/common.cpp b/transformer_engine/paddle/csrc/common.cpp index b0d39f1205e3a7514378b4ea7678466bdf987463..5e35a28a6bc70e7f0b2b6ec25fc902209635c0e3 100644 --- a/transformer_engine/paddle/csrc/common.cpp +++ b/transformer_engine/paddle/csrc/common.cpp @@ -11,75 +11,73 @@ namespace paddle_ext { TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector &shape, const DType type) { - return TensorWrapper(const_cast(data_ptr), shape, type); + return TensorWrapper(const_cast(data_ptr), shape, type); } TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type) { - return TensorWrapper(data_ptr, shape, type); + return TensorWrapper(data_ptr, shape, type); } TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector &shape, const DType type, void *amax_ptr, void *scale_ptr, void *scale_inv_ptr) { - return TensorWrapper(data_ptr, shape, type, reinterpret_cast(amax_ptr), - 
reinterpret_cast(scale_ptr), - reinterpret_cast(scale_inv_ptr)); + return TensorWrapper(data_ptr, shape, type, reinterpret_cast(amax_ptr), + reinterpret_cast(scale_ptr), + reinterpret_cast(scale_inv_ptr)); } TensorWrapper MakeNvteTensor(paddle::Tensor &tensor) { // NOLINT - return MakeNvteTensor(tensor.data(), GetShapeArray(tensor), Paddle2NvteDType(tensor.dtype())); + return MakeNvteTensor(tensor.data(), GetShapeArray(tensor), Paddle2NvteDType(tensor.dtype())); } TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor) { - return MakeNvteTensor(const_cast(tensor.data()), GetShapeArray(tensor), - Paddle2NvteDType(tensor.dtype())); + return MakeNvteTensor(const_cast(tensor.data()), GetShapeArray(tensor), + Paddle2NvteDType(tensor.dtype())); } paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place, bool init_to_zeros) { - auto size = shape.ndim; - if (size == 2 && init_to_zeros) { - return paddle::zeros( - {static_cast(shape.data[0]), static_cast(shape.data[1])}, - Nvte2PaddleDType(type), place); - } else if (size == 2) { - return paddle::empty( - {static_cast(shape.data[0]), static_cast(shape.data[1])}, - Nvte2PaddleDType(type), place); - } else if (size == 1 && init_to_zeros) { - return paddle::zeros({static_cast(shape.data[0])}, Nvte2PaddleDType(type), place); - } else if (size == 1) { - return paddle::empty({static_cast(shape.data[0])}, Nvte2PaddleDType(type), place); - } - NVTE_CHECK(false, "Should never reach here! func: AllocateSpace"); + auto size = shape.ndim; + if (size == 2 && init_to_zeros) { + return paddle::zeros({static_cast(shape.data[0]), static_cast(shape.data[1])}, + Nvte2PaddleDType(type), place); + } else if (size == 2) { + return paddle::empty({static_cast(shape.data[0]), static_cast(shape.data[1])}, + Nvte2PaddleDType(type), place); + } else if (size == 1 && init_to_zeros) { + return paddle::zeros({static_cast(shape.data[0])}, Nvte2PaddleDType(type), place); + } else if (size == 1) { + return paddle::empty({static_cast(shape.data[0])}, Nvte2PaddleDType(type), place); + } + NVTE_CHECK(false, "Should never reach here! 
func: AllocateSpace"); } // MHA utils // convert QKV layout to enum NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout) { - static const std::unordered_map layout_map = { - {"sb3hd", NVTE_QKV_Layout::NVTE_SB3HD}, - {"sbh3d", NVTE_QKV_Layout::NVTE_SBH3D}, - {"sbhd_sb2hd", NVTE_QKV_Layout::NVTE_SBHD_SB2HD}, - {"sbhd_sbh2d", NVTE_QKV_Layout::NVTE_SBHD_SBH2D}, - {"sbhd_sbhd_sbhd", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD}, - {"bs3hd", NVTE_QKV_Layout::NVTE_BS3HD}, - {"bsh3d", NVTE_QKV_Layout::NVTE_BSH3D}, - {"bshd_bs2hd", NVTE_QKV_Layout::NVTE_BSHD_BS2HD}, - {"bshd_bsh2d", NVTE_QKV_Layout::NVTE_BSHD_BSH2D}, - {"bshd_bshd_bshd", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD}, - {"t3hd", NVTE_QKV_Layout::NVTE_T3HD}, - {"th3d", NVTE_QKV_Layout::NVTE_TH3D}, - {"thd_t2hd", NVTE_QKV_Layout::NVTE_THD_T2HD}, - {"thd_th2d", NVTE_QKV_Layout::NVTE_THD_TH2D}, - {"thd_thd_thd", NVTE_QKV_Layout::NVTE_THD_THD_THD}, - }; + static const std::unordered_map layout_map = { + {"sb3hd", NVTE_QKV_Layout::NVTE_SB3HD}, + {"sbh3d", NVTE_QKV_Layout::NVTE_SBH3D}, + {"sbhd_sb2hd", NVTE_QKV_Layout::NVTE_SBHD_SB2HD}, + {"sbhd_sbh2d", NVTE_QKV_Layout::NVTE_SBHD_SBH2D}, + {"sbhd_sbhd_sbhd", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD}, + {"bs3hd", NVTE_QKV_Layout::NVTE_BS3HD}, + {"bsh3d", NVTE_QKV_Layout::NVTE_BSH3D}, + {"bshd_bs2hd", NVTE_QKV_Layout::NVTE_BSHD_BS2HD}, + {"bshd_bsh2d", NVTE_QKV_Layout::NVTE_BSHD_BSH2D}, + {"bshd_bshd_bshd", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD}, + {"t3hd", NVTE_QKV_Layout::NVTE_T3HD}, + {"th3d", NVTE_QKV_Layout::NVTE_TH3D}, + {"thd_t2hd", NVTE_QKV_Layout::NVTE_THD_T2HD}, + {"thd_th2d", NVTE_QKV_Layout::NVTE_THD_TH2D}, + {"thd_thd_thd", NVTE_QKV_Layout::NVTE_THD_THD_THD}, + }; - auto it = layout_map.find(qkv_layout); - if (it != layout_map.end()) { - return it->second; - } else { - NVTE_ERROR("Invalid QKV layout string: " + qkv_layout); - } + auto it = layout_map.find(qkv_layout); + if (it != layout_map.end()) { + return it->second; + } else { + NVTE_ERROR("Invalid QKV layout string: " + qkv_layout); + } } } // namespace paddle_ext diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h index 3224f98ebc928c35ce352f6183e912705bd12996..3f1be258785afadec908dff2723175bc8962e53f 100644 --- a/transformer_engine/paddle/csrc/common.h +++ b/transformer_engine/paddle/csrc/common.h @@ -5,13 +5,7 @@ ************************************************************************/ #pragma once -#include -#include - #include -#include "paddle/extension.h" -#include "paddle/phi/backends/all_context.h" - #include #include #include @@ -22,56 +16,62 @@ #include #include #include + +#include +#include + #include "common/util/logging.h" +#include "paddle/extension.h" +#include "paddle/phi/backends/all_context.h" namespace transformer_engine { namespace paddle_ext { // Paddle Tensor Utils template inline const void *GetDataPtr(const paddle::Tensor &x, int64_t index) { - if (index < 0 || index >= x.numel()) { - NVTE_ERROR("Index out of bound"); - } - return reinterpret_cast(x.data() + static_cast(index)); + if (index < 0 || index >= x.numel()) { + NVTE_ERROR("Index out of bound"); + } + return reinterpret_cast(x.data() + static_cast(index)); } template inline void *GetDataPtr(paddle::Tensor &x, int64_t index) { // NOLINT - if (index < 0 || index >= x.numel()) { - NVTE_ERROR("Index out of bound"); - } - return reinterpret_cast(x.data() + static_cast(index)); + if (index < 0 || index >= x.numel()) { + NVTE_ERROR("Index out of bound"); + } + return reinterpret_cast(x.data() + 
static_cast(index)); } template inline const void *GetOptionalDataPtr(const paddle::optional &x, int64_t index) { - return x ? GetDataPtr(*x, index) : nullptr; + return x ? GetDataPtr(*x, index) : nullptr; } template inline void *GetOptionalDataPtr(paddle::optional &x, int64_t index) { // NOLINT - return x ? GetDataPtr(*x, index) : nullptr; + return x ? GetDataPtr(*x, index) : nullptr; } inline const void *GetOptionalDataPtr(const paddle::optional &x) { - return x ? x->data() : nullptr; + return x ? x->data() : nullptr; } inline void *GetOptionalDataPtr(paddle::optional &x) { // NOLINT - return x ? x->data() : nullptr; + return x ? x->data() : nullptr; } inline std::vector GetShapeArray(const paddle::Tensor &x) { - std::vector shapes; - for (auto dim : x.shape()) { - shapes.push_back(static_cast(dim)); - } - return shapes; + std::vector shapes; + for (auto dim : x.shape()) { + shapes.push_back(static_cast(dim)); + } + return shapes; } inline std::vector GetShapeArray(const paddle::optional &x) { - if (x) return GetShapeArray(x.get()); - return {0}; + if (x) return GetShapeArray(x.get()); + return {0}; } paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place, @@ -79,96 +79,96 @@ paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const pad // DType Utils inline paddle::DataType Nvte2PaddleDType(DType t) { - switch (t) { - case DType::kInt32: - case DType::kFloat32: - return paddle::DataType::FLOAT32; - case DType::kFloat16: - return paddle::DataType::FLOAT16; - case DType::kBFloat16: - return paddle::DataType::BFLOAT16; - case DType::kByte: - case DType::kFloat8E4M3: - case DType::kFloat8E5M2: - return paddle::DataType::UINT8; - default: - NVTE_ERROR("Invalid type"); - } + switch (t) { + case DType::kInt32: + case DType::kFloat32: + return paddle::DataType::FLOAT32; + case DType::kFloat16: + return paddle::DataType::FLOAT16; + case DType::kBFloat16: + return paddle::DataType::BFLOAT16; + case DType::kByte: + case DType::kFloat8E4M3: + case DType::kFloat8E5M2: + return paddle::DataType::UINT8; + default: + NVTE_ERROR("Invalid type"); + } } inline DType Paddle2NvteDType(paddle::DataType t) { - switch (t) { - case paddle::DataType::FLOAT16: - return DType::kFloat16; - case paddle::DataType::FLOAT32: - return DType::kFloat32; - case paddle::DataType::BFLOAT16: - return DType::kBFloat16; - case paddle::DataType::BOOL: - return DType::kByte; - case paddle::DataType::UINT8: - return DType::kByte; - case paddle::DataType::INT32: - return DType::kInt32; - case paddle::DataType::INT64: - return DType::kInt64; - default: - NVTE_ERROR("Invalid type"); - } + switch (t) { + case paddle::DataType::FLOAT16: + return DType::kFloat16; + case paddle::DataType::FLOAT32: + return DType::kFloat32; + case paddle::DataType::BFLOAT16: + return DType::kBFloat16; + case paddle::DataType::BOOL: + return DType::kByte; + case paddle::DataType::UINT8: + return DType::kByte; + case paddle::DataType::INT32: + return DType::kInt32; + case paddle::DataType::INT64: + return DType::kInt64; + default: + NVTE_ERROR("Invalid type"); + } } inline DType Int2NvteDType(int64_t dtype) { - if (dtype >= 0 && dtype < static_cast(DType::kNumTypes)) { - return static_cast(dtype); - } else { - NVTE_ERROR("Type not supported."); - } + if (dtype >= 0 && dtype < static_cast(DType::kNumTypes)) { + return static_cast(dtype); + } else { + NVTE_ERROR("Type not supported."); + } } // get the fused attention backend inline NVTE_Fused_Attn_Backend get_fused_attn_backend( const 
transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) { - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim); - return fused_attention_backend; + float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim) { + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend(static_cast(q_dtype), static_cast(kv_dtype), + qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads, + num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim); + return fused_attention_backend; } // CUDA Utils class cudaDevicePropertiesManager { public: - static cudaDevicePropertiesManager &Instance() { - static thread_local cudaDevicePropertiesManager instance; - return instance; + static cudaDevicePropertiesManager &Instance() { + static thread_local cudaDevicePropertiesManager instance; + return instance; + } + + int GetMultiProcessorCount() { + if (!prop_queried_) { + int device_id; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + cudaGetDeviceProperties(&prop_, device_id); + prop_queried_ = true; } - - int GetMultiProcessorCount() { - if (!prop_queried_) { - int device_id; - NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); - cudaGetDeviceProperties(&prop_, device_id); - prop_queried_ = true; - } - return prop_.multiProcessorCount; - } - - int GetMajor() { - if (!prop_queried_) { - int device_id; - NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); - cudaGetDeviceProperties(&prop_, device_id); - prop_queried_ = true; - } - return prop_.major; + return prop_.multiProcessorCount; + } + + int GetMajor() { + if (!prop_queried_) { + int device_id; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + cudaGetDeviceProperties(&prop_, device_id); + prop_queried_ = true; } + return prop_.major; + } private: - bool prop_queried_ = false; - cudaDeviceProp prop_; + bool prop_queried_ = false; + cudaDeviceProp prop_; }; // NVTE Tensor Utils diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 68d22d45c64c4d298e78bf9d492f399b4103cb5a..1ba7f8ed3e4f16998383154409992c8317604e0a 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -16,28 +16,28 @@ namespace paddle_ext { // convert bias type to enum NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) { - if (bias_type == "no_bias") { - return NVTE_Bias_Type::NVTE_NO_BIAS; - } else if (bias_type == "pre_scale_bias") { - return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS; - } else if (bias_type == "post_scale_bias") { - return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; - } else { - NVTE_ERROR("Invalid bias type. \n"); - } + if (bias_type == "no_bias") { + return NVTE_Bias_Type::NVTE_NO_BIAS; + } else if (bias_type == "pre_scale_bias") { + return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS; + } else if (bias_type == "post_scale_bias") { + return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; + } else { + NVTE_ERROR("Invalid bias type. 
\n"); + } } // convert attn mask type to enum NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) { - if (mask_type == "padding") { - return NVTE_Mask_Type::NVTE_PADDING_MASK; - } else if (mask_type == "causal") { - return NVTE_Mask_Type::NVTE_CAUSAL_MASK; - } else if (mask_type == "no_mask") { - return NVTE_Mask_Type::NVTE_NO_MASK; - } else { - NVTE_ERROR("Invalid attention mask type. \n"); - } + if (mask_type == "padding") { + return NVTE_Mask_Type::NVTE_PADDING_MASK; + } else if (mask_type == "causal") { + return NVTE_Mask_Type::NVTE_CAUSAL_MASK; + } else if (mask_type == "no_mask") { + return NVTE_Mask_Type::NVTE_NO_MASK; + } else { + NVTE_ERROR("Invalid attention mask type. \n"); + } } void cast_to_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, @@ -45,46 +45,46 @@ void cast_to_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, paddle::Tensor &amax, // NOLINT paddle::Tensor &scale_inv, // NOLINT int64_t index, int64_t otype) { - auto shape = GetShapeArray(input); + auto shape = GetShapeArray(input); - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor( - output.data(), shape, Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); + auto input_cu = MakeNvteTensor(input); + auto output_cu = MakeNvteTensor( + output.data(), shape, Int2NvteDType(otype), GetDataPtr(amax, index), + const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - nvte_fp8_quantize(input_cu.data(), output_cu.data(), input.stream()); + nvte_fp8_quantize(input_cu.data(), output_cu.data(), input.stream()); } std::vector cast_from_fp8(const paddle::Tensor &input, const paddle::Tensor &scale_inv, int64_t index, int64_t itype, int64_t otype) { - auto shape = GetShapeArray(input); + auto shape = GetShapeArray(input); - auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype))); - auto input_cu = - MakeNvteTensor(const_cast(input.data()), shape, Int2NvteDType(itype), nullptr, - nullptr, const_cast(GetDataPtr(scale_inv, index))); - auto output_cu = MakeNvteTensor(output); + auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype))); + auto input_cu = + MakeNvteTensor(const_cast(input.data()), shape, Int2NvteDType(itype), nullptr, + nullptr, const_cast(GetDataPtr(scale_inv, index))); + auto output_cu = MakeNvteTensor(output); - nvte_fp8_dequantize(input_cu.data(), output_cu.data(), input.stream()); + nvte_fp8_dequantize(input_cu.data(), output_cu.data(), input.stream()); - return {output}; + return {output}; } std::vector te_transpose(const paddle::Tensor &input, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - size_t M = shape[0]; - size_t N = shape[1]; + auto shape = GetShapeArray(input); + NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); + size_t M = shape[0]; + size_t N = shape[1]; - auto output = paddle::empty({input.shape()[1], input.shape()[0]}, input.dtype(), input.place()); + auto output = paddle::empty({input.shape()[1], input.shape()[0]}, input.dtype(), input.place()); - auto input_cu = MakeNvteTensor(const_cast(input.data()), {M, N}, Int2NvteDType(otype)); - auto output_cu = MakeNvteTensor(output.data(), {N, M}, Int2NvteDType(otype)); + auto input_cu = MakeNvteTensor(const_cast(input.data()), {M, N}, Int2NvteDType(otype)); + auto output_cu = MakeNvteTensor(output.data(), {N, M}, Int2NvteDType(otype)); - nvte_transpose(input_cu.data(), 
output_cu.data(), input.stream()); + nvte_transpose(input_cu.data(), output_cu.data(), input.stream()); - return {output}; + return {output}; } void te_cast_transpose(const paddle::Tensor &input, const paddle::Tensor &scale, @@ -93,23 +93,23 @@ void te_cast_transpose(const paddle::Tensor &input, const paddle::Tensor &scale, paddle::Tensor &amax, // NOLINT paddle::Tensor &scale_inv, // NOLINT int64_t index, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto input_cu = MakeNvteTensor(input); - void *amax_data = GetDataPtr(amax, index); - void *scale_data = const_cast(GetDataPtr(scale, index)); - void *scale_inv_data = GetDataPtr(scale_inv, index); - auto output_cast_cu = MakeNvteTensor(output_cast.data(), {M, N}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - auto output_transpose_cu = MakeNvteTensor(output_transpose.data(), {N, M}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - - nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - input.stream()); + auto shape = GetShapeArray(input); + NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); + + size_t M = shape[0]; + size_t N = shape[1]; + + auto input_cu = MakeNvteTensor(input); + void *amax_data = GetDataPtr(amax, index); + void *scale_data = const_cast(GetDataPtr(scale, index)); + void *scale_inv_data = GetDataPtr(scale_inv, index); + auto output_cast_cu = MakeNvteTensor(output_cast.data(), {M, N}, Int2NvteDType(otype), amax_data, + scale_data, scale_inv_data); + auto output_transpose_cu = MakeNvteTensor(output_transpose.data(), {N, M}, Int2NvteDType(otype), + amax_data, scale_data, scale_inv_data); + + nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), + input.stream()); } std::vector te_cast_transpose_bgrad(const paddle::Tensor &grad_output, @@ -117,43 +117,43 @@ std::vector te_cast_transpose_bgrad(const paddle::Tensor &grad_o paddle::Tensor &amax, // NOLINT paddle::Tensor &scale_inv, // NOLINT int64_t index, int64_t otype) { - auto shape = GetShapeArray(grad_output); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto grad_bias = - paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place()); - auto grad_output_cast = paddle::empty_like(grad_output, Nvte2PaddleDType(Int2NvteDType(otype)), - grad_output.place()); - auto grad_output_transpose = - paddle::empty({grad_output.shape()[1], grad_output.shape()[0]}, - Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place()); - - auto input_cu = MakeNvteTensor(grad_output); - void *amax_data = GetDataPtr(amax, index); - void *scale_data = const_cast(GetDataPtr(scale, index)); - void *scale_inv_data = GetDataPtr(scale_inv, index); - auto output_cast_cu = MakeNvteTensor(grad_output_cast.data(), {M, N}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - auto output_transpose_cu = - MakeNvteTensor(grad_output_transpose.data(), {N, M}, Int2NvteDType(otype), amax_data, - scale_data, scale_inv_data); - auto dbias_cu = MakeNvteTensor(grad_bias); - transformer_engine::TensorWrapper workspace; - - nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - dbias_cu.data(), workspace.data(), grad_output.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), 
workspace.dtype(), grad_output.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - dbias_cu.data(), workspace.data(), grad_output.stream()); - - return {grad_bias, grad_output_cast, grad_output_transpose}; + auto shape = GetShapeArray(grad_output); + NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); + + size_t M = shape[0]; + size_t N = shape[1]; + + auto grad_bias = + paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place()); + auto grad_output_cast = + paddle::empty_like(grad_output, Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place()); + auto grad_output_transpose = + paddle::empty({grad_output.shape()[1], grad_output.shape()[0]}, + Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place()); + + auto input_cu = MakeNvteTensor(grad_output); + void *amax_data = GetDataPtr(amax, index); + void *scale_data = const_cast(GetDataPtr(scale, index)); + void *scale_inv_data = GetDataPtr(scale_inv, index); + auto output_cast_cu = MakeNvteTensor(grad_output_cast.data(), {M, N}, Int2NvteDType(otype), + amax_data, scale_data, scale_inv_data); + auto output_transpose_cu = + MakeNvteTensor(grad_output_transpose.data(), {N, M}, Int2NvteDType(otype), amax_data, + scale_data, scale_inv_data); + auto dbias_cu = MakeNvteTensor(grad_bias); + transformer_engine::TensorWrapper workspace; + + nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), + dbias_cu.data(), workspace.data(), grad_output.stream()); + + // Fill workspace + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place()); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + + nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), + dbias_cu.data(), workspace.data(), grad_output.stream()); + + return {grad_bias, grad_output_cast, grad_output_transpose}; } void te_gemm(const paddle::Tensor &A, const paddle::optional &A_scale_inverse, @@ -166,116 +166,115 @@ void te_gemm(const paddle::Tensor &A, const paddle::optional &A_ int64_t D_type, int64_t bias_type, bool transa, bool transb, bool grad, int64_t workspace_size, bool accumulate, bool use_split_accumulator, int64_t math_sm_count) { - auto te_A = MakeNvteTensor( - const_cast(A.data()), GetShapeArray(A), Int2NvteDType(A_type), nullptr, nullptr, - const_cast(GetOptionalDataPtr(A_scale_inverse, A_index))); - auto te_B = MakeNvteTensor( - const_cast(B.data()), GetShapeArray(B), Int2NvteDType(B_type), nullptr, nullptr, - const_cast(GetOptionalDataPtr(B_scale_inverse, B_index))); - auto te_D = MakeNvteTensor(D.data(), GetShapeArray(D), Int2NvteDType(D_type), - GetOptionalDataPtr(D_amax, D_index), - GetOptionalDataPtr(D_scale, D_index), nullptr); - - auto te_bias = MakeNvteTensor(const_cast(GetOptionalDataPtr(bias)), GetShapeArray(bias), - Int2NvteDType(bias_type)); - - DType gelu_dtype = - pre_gelu_out ? 
Paddle2NvteDType(pre_gelu_out->dtype()) : Int2NvteDType(D_type); - auto te_pre_gelu_out = - MakeNvteTensor(GetOptionalDataPtr(pre_gelu_out), GetShapeArray(pre_gelu_out), gelu_dtype); - auto te_workspace = - MakeNvteTensor(workspace.data(), {static_cast(workspace_size)}, DType::kByte); - - nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(), - transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator, - math_sm_count, A.stream()); + auto te_A = MakeNvteTensor( + const_cast(A.data()), GetShapeArray(A), Int2NvteDType(A_type), nullptr, nullptr, + const_cast(GetOptionalDataPtr(A_scale_inverse, A_index))); + auto te_B = MakeNvteTensor( + const_cast(B.data()), GetShapeArray(B), Int2NvteDType(B_type), nullptr, nullptr, + const_cast(GetOptionalDataPtr(B_scale_inverse, B_index))); + auto te_D = MakeNvteTensor(D.data(), GetShapeArray(D), Int2NvteDType(D_type), + GetOptionalDataPtr(D_amax, D_index), + GetOptionalDataPtr(D_scale, D_index), nullptr); + + auto te_bias = MakeNvteTensor(const_cast(GetOptionalDataPtr(bias)), GetShapeArray(bias), + Int2NvteDType(bias_type)); + + DType gelu_dtype = pre_gelu_out ? Paddle2NvteDType(pre_gelu_out->dtype()) : Int2NvteDType(D_type); + auto te_pre_gelu_out = + MakeNvteTensor(GetOptionalDataPtr(pre_gelu_out), GetShapeArray(pre_gelu_out), gelu_dtype); + auto te_workspace = + MakeNvteTensor(workspace.data(), {static_cast(workspace_size)}, DType::kByte); + + nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(), + transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator, + math_sm_count, A.stream()); } std::vector te_gelu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, paddle::Tensor &amax, // NOLINT paddle::Tensor &scale_inv, // NOLINT int64_t index, int64_t otype) { - auto output = paddle::empty_like(input, Nvte2PaddleDType(DType::kByte), input.place()); + auto output = paddle::empty_like(input, Nvte2PaddleDType(DType::kByte), input.place()); - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor( - output.data(), GetShapeArray(input), Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); + auto input_cu = MakeNvteTensor(input); + auto output_cu = MakeNvteTensor( + output.data(), GetShapeArray(input), Int2NvteDType(otype), GetDataPtr(amax, index), + const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - nvte_gelu(input_cu.data(), output_cu.data(), input.stream()); + nvte_gelu(input_cu.data(), output_cu.data(), input.stream()); - return {output}; + return {output}; } std::vector te_gelu(const paddle::Tensor &input, int64_t otype) { - auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); + auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(input), Int2NvteDType(otype)); + auto input_cu = MakeNvteTensor(input); + auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(input), Int2NvteDType(otype)); - nvte_gelu(input_cu.data(), output_cu.data(), input.stream()); + nvte_gelu(input_cu.data(), output_cu.data(), input.stream()); - return {output}; + return {output}; } std::vector te_swiglu(const paddle::Tensor &input, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); + auto 
shape = GetShapeArray(input); + NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - size_t M = shape[0]; - size_t N = shape[1]; + size_t M = shape[0]; + size_t N = shape[1]; - auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2}, - Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); + auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2}, + Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(output), Int2NvteDType(otype)); + auto input_cu = MakeNvteTensor(input); + auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(output), Int2NvteDType(otype)); - nvte_swiglu(input_cu.data(), output_cu.data(), input.stream()); + nvte_swiglu(input_cu.data(), output_cu.data(), input.stream()); - return {output}; + return {output}; } std::vector te_swiglu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, paddle::Tensor &amax, // NOLINT paddle::Tensor &scale_inv, // NOLINT int64_t index, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); + auto shape = GetShapeArray(input); + NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - size_t M = shape[0]; - size_t N = shape[1]; + size_t M = shape[0]; + size_t N = shape[1]; - auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2}, - Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); + auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2}, + Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor( - output.data(), GetShapeArray(output), Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); + auto input_cu = MakeNvteTensor(input); + auto output_cu = MakeNvteTensor( + output.data(), GetShapeArray(output), Int2NvteDType(otype), GetDataPtr(amax, index), + const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - nvte_swiglu(input_cu.data(), output_cu.data(), input.stream()); + nvte_swiglu(input_cu.data(), output_cu.data(), input.stream()); - return {output}; + return {output}; } std::vector te_dswiglu(const paddle::Tensor &grad, const paddle::Tensor &input, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); + auto shape = GetShapeArray(input); + NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - size_t M = shape[0]; - size_t N = shape[1]; + size_t M = shape[0]; + size_t N = shape[1]; - auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); + auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - auto input_cu = MakeNvteTensor(input.data(), {M, N}, Paddle2NvteDType(input.dtype())); - auto grad_cu = MakeNvteTensor(grad.data(), {M, N / 2}, Paddle2NvteDType(grad.dtype())); - auto output_cu = MakeNvteTensor(output.data(), {M, N}, Paddle2NvteDType(output.dtype())); + auto input_cu = MakeNvteTensor(input.data(), {M, N}, Paddle2NvteDType(input.dtype())); + auto grad_cu = MakeNvteTensor(grad.data(), {M, N / 2}, Paddle2NvteDType(grad.dtype())); + auto output_cu = MakeNvteTensor(output.data(), {M, N}, Paddle2NvteDType(output.dtype())); - nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), input.stream()); + 
nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), input.stream()); - return {output}; + return {output}; } std::vector te_cast_transpose_bgrad_dgelu(const paddle::Tensor &grad_output, @@ -284,49 +283,48 @@ std::vector te_cast_transpose_bgrad_dgelu(const paddle::Tensor & paddle::Tensor &amax, // NOLINT paddle::Tensor &scale_inv, // NOLINT int64_t index, int64_t otype) { - auto shape = GetShapeArray(grad_output); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); + auto shape = GetShapeArray(grad_output); + NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - size_t M = shape[0]; - size_t N = shape[1]; + size_t M = shape[0]; + size_t N = shape[1]; - // DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); - auto grad_bias = - paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place()); + // DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); + auto grad_bias = + paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place()); - auto dgelu = - paddle::empty_like(grad_output, Nvte2PaddleDType(DType::kByte), grad_output.place()); + auto dgelu = paddle::empty_like(grad_output, Nvte2PaddleDType(DType::kByte), grad_output.place()); - auto dgelu_transpose = paddle::empty({grad_output.shape()[1], grad_output.shape()[0]}, - Nvte2PaddleDType(DType::kByte), grad_output.place()); + auto dgelu_transpose = paddle::empty({grad_output.shape()[1], grad_output.shape()[0]}, + Nvte2PaddleDType(DType::kByte), grad_output.place()); - void *amax_data = GetDataPtr(amax, index); - void *scale_data = const_cast(GetDataPtr(scale, index)); - void *scale_inv_data = GetDataPtr(scale_inv, index); + void *amax_data = GetDataPtr(amax, index); + void *scale_data = const_cast(GetDataPtr(scale, index)); + void *scale_inv_data = GetDataPtr(scale_inv, index); - TensorWrapper workspace; + TensorWrapper workspace; - auto gelu_input_cu = MakeNvteTensor(gelu_input); - auto input_cu = MakeNvteTensor(grad_output); - auto cast_output_cu = MakeNvteTensor(dgelu.data(), {M, N}, Int2NvteDType(otype), amax_data, - scale_data, scale_inv_data); - auto transposed_output_cu = MakeNvteTensor(dgelu_transpose.data(), {N, M}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - auto dbias_cu = MakeNvteTensor(grad_bias); + auto gelu_input_cu = MakeNvteTensor(gelu_input); + auto input_cu = MakeNvteTensor(grad_output); + auto cast_output_cu = MakeNvteTensor(dgelu.data(), {M, N}, Int2NvteDType(otype), amax_data, + scale_data, scale_inv_data); + auto transposed_output_cu = MakeNvteTensor(dgelu_transpose.data(), {N, M}, Int2NvteDType(otype), + amax_data, scale_data, scale_inv_data); + auto dbias_cu = MakeNvteTensor(grad_bias); - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - grad_output.stream()); + nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), + transposed_output_cu.data(), dbias_cu.data(), workspace.data(), + grad_output.stream()); - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + // Fill workspace + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place()); + workspace = MakeNvteTensor(workspace_data.data(), 
workspace.shape(), workspace.dtype()); - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - grad_output.stream()); + nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), + transposed_output_cu.data(), dbias_cu.data(), workspace.data(), + grad_output.stream()); - return {dgelu, dgelu_transpose, grad_bias}; + return {dgelu, dgelu_transpose, grad_bias}; } std::vector te_layernorm_fwd_fp8(const paddle::Tensor &input, @@ -337,171 +335,167 @@ std::vector te_layernorm_fwd_fp8(const paddle::Tensor &input, paddle::Tensor &scale_inv, // NOLINT float eps, int64_t index, int64_t otype, int64_t sm_margin, bool zero_centered_gamma) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - auto mu = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto rsigma = - paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto beta_cu = MakeNvteTensor(bias); - auto z_cu = MakeNvteTensor( - ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - auto mu_cu = MakeNvteTensor(mu); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace and barrier tensors with the required config - const auto func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - // Fill workspace and barrier - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - - // Actual call to fwd kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - return {ln_out, mu, rsigma}; + auto shape = GetShapeArray(input); + NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); + + size_t N = shape[0]; + size_t H = shape[1]; + + auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); + auto mu = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); + auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); + auto input_cu = MakeNvteTensor(input); + auto gamma_cu = MakeNvteTensor(weight); + auto beta_cu = MakeNvteTensor(bias); + auto z_cu = MakeNvteTensor( + ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr(amax, index), + const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); + auto mu_cu = MakeNvteTensor(mu); + auto rsigma_cu = MakeNvteTensor(rsigma); + TensorWrapper workspace, barrier; + + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + + // This call populates workspace and barrier tensors with the required config + const auto func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; + func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), + rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + + // Fill workspace and barrier + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); + auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); + + // Actual call to fwd kernel + func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), + rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + + return {ln_out, mu, rsigma}; } std::vector te_layernorm_fwd(const paddle::Tensor &input, const paddle::Tensor &weight, const paddle::Tensor &bias, float eps, int64_t otype, int64_t sm_margin, bool zero_centered_gamma) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, input.dtype(), input.place()); - auto mu = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto rsigma = - paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto beta_cu = MakeNvteTensor(bias); - auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); - auto mu_cu = MakeNvteTensor(mu); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace and barrier tensors with the required config - const auto func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - // Fill workspace and barrier - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - - // Actual call to fwd kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - return {ln_out, mu, rsigma}; + auto shape = GetShapeArray(input); + NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); + + size_t N = shape[0]; + size_t H = shape[1]; + + auto ln_out = paddle::empty_like(input, input.dtype(), input.place()); + auto mu = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); + auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); + auto input_cu = MakeNvteTensor(input); + auto gamma_cu = MakeNvteTensor(weight); + auto beta_cu = MakeNvteTensor(bias); + auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); + auto mu_cu = MakeNvteTensor(mu); + auto rsigma_cu = MakeNvteTensor(rsigma); + TensorWrapper workspace, barrier; + + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + + // This call populates workspace and barrier tensors with the required config + const auto func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; + func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), + rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + + // Fill workspace and barrier + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); + auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); + + // Actual call to fwd kernel + func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), + rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + + return {ln_out, mu, rsigma}; } std::vector te_layernorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x, const paddle::Tensor &mu, const paddle::Tensor &rsigma, const paddle::Tensor &gamma, int64_t sm_margin, bool zero_centered_gamma) { - auto dx = paddle::empty_like(x, x.dtype(), x.place()); - auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - auto dbeta = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - - TensorWrapper workspace, barrier, dgamma_part, dbeta_part; - - auto dz_cu = MakeNvteTensor(dz); - auto x_cu = MakeNvteTensor(x); - auto mu_cu = MakeNvteTensor(mu); - auto rsigma_cu = MakeNvteTensor(rsigma); - auto gamma_cu = MakeNvteTensor(gamma); - auto dx_cu = MakeNvteTensor(dx); - auto dgamma_cu = MakeNvteTensor(dgamma); - auto dbeta_cu = MakeNvteTensor(dbeta); - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates tensors with the required config. - const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), - dz.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - // Alloc space for Tensors. - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true); - auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place()); - auto dbeta_part_data = AllocateSpace(dbeta_part.shape(), dbeta_part.dtype(), x.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype()); - dbeta_part = MakeNvteTensor(dbeta_part_data.data(), dbeta_part.shape(), dbeta_part.dtype()); - - // Actual call to bwd kernel. 
- bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), - dz.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - return {dx, dgamma, dbeta}; + auto dx = paddle::empty_like(x, x.dtype(), x.place()); + auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); + auto dbeta = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); + + TensorWrapper workspace, barrier, dgamma_part, dbeta_part; + + auto dz_cu = MakeNvteTensor(dz); + auto x_cu = MakeNvteTensor(x); + auto mu_cu = MakeNvteTensor(mu); + auto rsigma_cu = MakeNvteTensor(rsigma); + auto gamma_cu = MakeNvteTensor(gamma); + auto dx_cu = MakeNvteTensor(dx); + auto dgamma_cu = MakeNvteTensor(dgamma); + auto dbeta_cu = MakeNvteTensor(dbeta); + + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + + // This call populates tensors with the required config. + const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; + bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), dz.stream(), + num_sm - sm_margin, workspace.data(), barrier.data()); + + // Alloc space for Tensors. + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); + auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true); + auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place()); + auto dbeta_part_data = AllocateSpace(dbeta_part.shape(), dbeta_part.dtype(), x.place()); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); + dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype()); + dbeta_part = MakeNvteTensor(dbeta_part_data.data(), dbeta_part.shape(), dbeta_part.dtype()); + + // Actual call to bwd kernel. 
+ bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), dz.stream(), + num_sm - sm_margin, workspace.data(), barrier.data()); + + return {dx, dgamma, dbeta}; } std::vector te_rmsnorm_fwd(const paddle::Tensor &input, const paddle::Tensor &weight, float eps, int64_t otype, int64_t sm_margin, bool zero_centered_gamma) { - NVTE_CHECK(zero_centered_gamma == false, - "zero_centered_gamma is not supported yet for RMSNorm."); - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); + NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm."); + auto shape = GetShapeArray(input); + NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - size_t N = shape[0]; - size_t H = shape[1]; + size_t N = shape[0]; + size_t H = shape[1]; - auto ln_out = paddle::empty_like(input, input.dtype(), input.place()); - auto rsigma = - paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; + auto ln_out = paddle::empty_like(input, input.dtype(), input.place()); + auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); + auto input_cu = MakeNvteTensor(input); + auto gamma_cu = MakeNvteTensor(weight); + auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); + auto rsigma_cu = MakeNvteTensor(rsigma); + TensorWrapper workspace, barrier; - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - // This call populates workspace and barrier tensors with the required config + // This call populates workspace and barrier tensors with the required config - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), + input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - // Fill workspace and barrier - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); + // Fill workspace and barrier + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); + auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - // Actual call to fwd kernel - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + // Actual call to fwd kernel + nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), + input.stream(), num_sm - sm_margin, 
workspace.data(), barrier.data()); - return {ln_out, rsigma}; + return {ln_out, rsigma}; } std::vector te_rmsnorm_fwd_fp8(const paddle::Tensor &input, @@ -511,88 +505,85 @@ std::vector te_rmsnorm_fwd_fp8(const paddle::Tensor &input, paddle::Tensor &scale_inv, // NOLINT float eps, int64_t index, int64_t otype, int64_t sm_margin, bool zero_centered_gamma) { - NVTE_CHECK(zero_centered_gamma == false, - "zero_centered_gamma is not supported yet for RMSNorm."); - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - auto rsigma = - paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto z_cu = MakeNvteTensor( - ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace and barrier tensors with the required config - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - // Fill workspace and barrier - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - - // Actual call to fwd kernel - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - return {ln_out, rsigma}; + NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm."); + auto shape = GetShapeArray(input); + NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); + + size_t N = shape[0]; + size_t H = shape[1]; + + auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); + auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); + auto input_cu = MakeNvteTensor(input); + auto gamma_cu = MakeNvteTensor(weight); + auto z_cu = MakeNvteTensor( + ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr(amax, index), + const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); + auto rsigma_cu = MakeNvteTensor(rsigma); + TensorWrapper workspace, barrier; + + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + + // This call populates workspace and barrier tensors with the required config + nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), + input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + + // Fill workspace and barrier + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); + auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + barrier = 
MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); + + // Actual call to fwd kernel + nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), + input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + + return {ln_out, rsigma}; } std::vector te_rmsnorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x, const paddle::Tensor &rsigma, const paddle::Tensor &gamma, int64_t sm_margin, bool zero_centered_gamma) { - NVTE_CHECK(zero_centered_gamma == false, - "zero_centered_gamma is not supported yet for RMSNorm."); - auto dx = paddle::empty_like(x, x.dtype(), x.place()); - auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - - TensorWrapper workspace, barrier, dgamma_part; - - auto dz_cu = MakeNvteTensor(dz); - auto x_cu = MakeNvteTensor(x); - auto rsigma_cu = MakeNvteTensor(rsigma); - auto gamma_cu = MakeNvteTensor(gamma); - auto dx_cu = MakeNvteTensor(dx); - auto dgamma_cu = MakeNvteTensor(dgamma); - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates tensors with the required config. - nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin, - workspace.data(), barrier.data()); - - // Alloc space for Tensors. - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true); - auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype()); - - // Actual call to bwd kernel. - nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin, - workspace.data(), barrier.data()); - - return {dx, dgamma}; + NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm."); + auto dx = paddle::empty_like(x, x.dtype(), x.place()); + auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); + + TensorWrapper workspace, barrier, dgamma_part; + + auto dz_cu = MakeNvteTensor(dz); + auto x_cu = MakeNvteTensor(x); + auto rsigma_cu = MakeNvteTensor(rsigma); + auto gamma_cu = MakeNvteTensor(gamma); + auto dx_cu = MakeNvteTensor(dx); + auto dgamma_cu = MakeNvteTensor(dgamma); + + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + + // This call populates tensors with the required config. + nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin, + workspace.data(), barrier.data()); + + // Alloc space for Tensors. 
+ auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); + auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true); + auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place()); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); + dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype()); + + // Actual call to bwd kernel. + nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin, + workspace.data(), barrier.data()); + + return {dx, dgamma}; } __global__ void set_rng_state(std::pair seed_offset, int64_t *rng_state_ptr) { - rng_state_ptr[0] = static_cast(seed_offset.first); - rng_state_ptr[1] = static_cast(seed_offset.second); + rng_state_ptr[0] = static_cast(seed_offset.first); + rng_state_ptr[1] = static_cast(seed_offset.second); } void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens, @@ -605,75 +596,74 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor float p_dropout, const std::string &qkv_layout, const std::string &bias_type, const std::string &attn_mask_type, const int64_t qkv_type, int64_t rng_elts_per_thread) { - if (is_training && !softmax_aux) { - NVTE_ERROR("softmax_aux must be provided when training. \n"); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_QKV = MakeNvteTensor(QKV); - te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); - te_O = MakeNvteTensor(O); - } else { // TODO: support fp8 - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. 
\n"); - } - if ((bias_type != "no_bias") && Bias) { - auto bias_shape = Bias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); - } - te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast(b + 1)}, DType::kInt32); - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // extract random number generator seed and offset - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(QKV.place()); - auto gen_cuda = dev_ctx->GetGenerator(); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - set_rng_state<<<1, 1, 0, QKV.stream()>>>(seed_offset, static_cast(rng_state.data())); - - auto te_rng_state = MakeNvteTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_qkvpacked( - te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), - max_seqlen, is_training, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - auto *output_s = - reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - output_s->data.dptr = GetOptionalDataPtr(softmax_aux); - - // execute the kernel - nvte_fused_attn_fwd_qkvpacked( - te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), - max_seqlen, is_training, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + if (is_training && !softmax_aux) { + NVTE_ERROR("softmax_aux must be provided when training. \n"); + } + + auto qkv_dtype = Int2NvteDType(qkv_type); + // construct NVTE tensors + TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; + if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { + // BF16 or FP16 + te_QKV = MakeNvteTensor(QKV); + te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); + te_O = MakeNvteTensor(O); + } else { // TODO: support fp8 + NVTE_ERROR("Fused attention only supports BF16/FP16 data types. 
\n"); + } + if ((bias_type != "no_bias") && Bias) { + auto bias_shape = Bias->shape(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); + } + te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast(b + 1)}, DType::kInt32); + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // extract random number generator seed and offset + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(QKV.place()); + auto gen_cuda = dev_ctx->GetGenerator(); + auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); + set_rng_state<<<1, 1, 0, QKV.stream()>>>(seed_offset, static_cast(rng_state.data())); + + auto te_rng_state = MakeNvteTensor(rng_state); + + // create auxiliary output tensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + // create workspace + TensorWrapper workspace; + + auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_fwd_qkvpacked( + te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, + te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen, + is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), QKV.stream()); + + // allocate memory for workspace and auxiliary output tensors + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + + auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); + output_s->data.dptr = GetOptionalDataPtr(softmax_aux); + + // execute the kernel + nvte_fused_attn_fwd_qkvpacked( + te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, + te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen, + is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), QKV.stream()); + + // destroy tensor wrappers, but not allocated memory + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); } // fused attention BWD with packed QKV @@ -687,77 +677,77 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor int64_t max_seqlen, float attn_scale, float p_dropout, const std::string &qkv_layout, const std::string &bias_type, const std::string &attn_mask_type, int64_t qkv_type) { - TensorWrapper te_dBias; - if (bias_type != "no_bias" && dBias) { - auto bias_shape = dBias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_QKV = MakeNvteTensor(QKV); - te_O = MakeNvteTensor(O); - te_dO = MakeNvteTensor(dO); - te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dP = 
MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dQKV = MakeNvteTensor(dQKV); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - nvte_aux_tensor_pack.size = 2; // 1. softmax_aux 2. rng_state - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); - output_s->data.shape = - std::vector({static_cast(b), static_cast(h), - static_cast(max_seqlen), static_cast(max_seqlen)}); - output_s->data.dptr = const_cast(softmax_aux.data()); - fwd_rng_state->data.shape = std::vector({2}); - fwd_rng_state->data.dptr = const_cast(rng_state.data()); - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens; - te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast(b + 1)}, DType::kInt32); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - max_seqlen, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream()); - - // allocate memory for workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - max_seqlen, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + TensorWrapper te_dBias; + if (bias_type != "no_bias" && dBias) { + auto bias_shape = dBias->shape(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); + } + + auto qkv_dtype = Int2NvteDType(qkv_type); + // construct NVTE tensors + TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV; + if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { + // BF16 or FP16 + te_QKV = MakeNvteTensor(QKV); + te_O = MakeNvteTensor(O); + te_dO = MakeNvteTensor(dO); + te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); + te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); + te_dQKV = MakeNvteTensor(dQKV); + } else { + NVTE_ERROR("Fused attention only supports BF16/FP16 data types. 
\n"); + } + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // convert auxiliary tensors from forward into NVTETensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + nvte_aux_tensor_pack.size = 2; // 1. softmax_aux 2. rng_state + auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); + auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); + output_s->data.shape = + std::vector({static_cast(b), static_cast(h), + static_cast(max_seqlen), static_cast(max_seqlen)}); + output_s->data.dptr = const_cast(softmax_aux.data()); + fwd_rng_state->data.shape = std::vector({2}); + fwd_rng_state->data.dptr = const_cast(rng_state.data()); + + // create cu_seqlens tensorwrappers + TensorWrapper te_cu_seqlens; + te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast(b + 1)}, DType::kInt32); + + // create workspace + TensorWrapper workspace; + + auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_bwd_qkvpacked( + te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, + te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen, + attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), + QKV.stream()); + + // allocate memory for workspace + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + + // execute kernel + nvte_fused_attn_bwd_qkvpacked( + te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, + te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen, + attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), + QKV.stream()); + + // destroy tensor wrappers + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); } void te_fused_attn_fwd_kvpacked( @@ -770,91 +760,88 @@ void te_fused_attn_fwd_kvpacked( int64_t max_seqlen_q, int64_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, const std::string &qkv_layout, const std::string &bias_type, const std::string &attn_mask_type, const int64_t qkv_type, int64_t rng_elts_per_thread) { - if (is_training && !softmax_aux) { - NVTE_ERROR("softmax_aux must be provided when training. 
\n"); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor( - Q.data(), - {static_cast(total_seqs_q), static_cast(h), static_cast(d)}, - qkv_dtype); - te_KV = MakeNvteTensor( - KV.data(), - {static_cast(total_seqs_kv), 2, static_cast(h), static_cast(d)}, - qkv_dtype); - te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); - te_O = MakeNvteTensor( - O.data(), - {static_cast(total_seqs_q), static_cast(h), static_cast(d)}, - qkv_dtype); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - - if ((bias_type != "no_bias") && Bias) { - auto bias_shape = Bias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); - } - - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place()); - auto gen_cuda = dev_ctx->GetGenerator(); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast(rng_state.data())); - auto te_rng_state = MakeNvteTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, - max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum, - bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - auto *output_s = - reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - output_s->data.dptr = GetOptionalDataPtr(softmax_aux); - - // execute the kernel - nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, - max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum, - bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + if (is_training && !softmax_aux) { + 
NVTE_ERROR("softmax_aux must be provided when training. \n"); + } + + auto qkv_dtype = Int2NvteDType(qkv_type); + + // construct NVTE tensors + TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; + if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { + // BF16 or FP16 + te_Q = MakeNvteTensor( + Q.data(), + {static_cast(total_seqs_q), static_cast(h), static_cast(d)}, + qkv_dtype); + te_KV = MakeNvteTensor( + KV.data(), + {static_cast(total_seqs_kv), 2, static_cast(h), static_cast(d)}, + qkv_dtype); + te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); + te_O = MakeNvteTensor( + O.data(), + {static_cast(total_seqs_q), static_cast(h), static_cast(d)}, + qkv_dtype); + } else { + NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); + } + + if ((bias_type != "no_bias") && Bias) { + auto bias_shape = Bias->shape(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); + } + + te_cu_seqlens_q = + MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); + te_cu_seqlens_kv = + MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place()); + auto gen_cuda = dev_ctx->GetGenerator(); + auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); + set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast(rng_state.data())); + auto te_rng_state = MakeNvteTensor(rng_state); + + // create auxiliary output tensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + // create workspace + TensorWrapper workspace; + + auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_fwd_kvpacked( + te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, + te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), + te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); + + // allocate memory for workspace and auxiliary output tensors + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + + auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); + output_s->data.dptr = GetOptionalDataPtr(softmax_aux); + + // execute the kernel + nvte_fused_attn_fwd_kvpacked( + te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, + te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), + te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); + + // destroy tensor wrappers, but not allocated memory + 
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); } // fused attention BWD with packed KV @@ -871,84 +858,84 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K float attn_scale, float p_dropout, const std::string &qkv_layout, const std::string &bias_type, const std::string &attn_mask_type, int64_t qkv_type) { - TensorWrapper te_dBias; - if (bias_type != "no_bias" && dBias) { - auto bias_shape = dBias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor(Q); - te_KV = MakeNvteTensor(KV); - te_O = MakeNvteTensor(O); - te_dO = MakeNvteTensor(dO); - te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dQ = MakeNvteTensor(dQ); - te_dKV = MakeNvteTensor(dKV); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - nvte_aux_tensor_pack.size = 2; - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); - output_s->data.shape = std::vector({static_cast(b), static_cast(h), - static_cast(max_seqlen_q), - static_cast(max_seqlen_kv)}); - output_s->data.dptr = const_cast(softmax_aux.data()); - fwd_rng_state->data.shape = std::vector({2}); - fwd_rng_state->data.dptr = const_cast(rng_state.data()); - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_kvpacked( - te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), - max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); - - // allocate memory for workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_kvpacked( - te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), 
dummy_seq_offsets.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), - max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + TensorWrapper te_dBias; + if (bias_type != "no_bias" && dBias) { + auto bias_shape = dBias->shape(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); + } + + auto qkv_dtype = Int2NvteDType(qkv_type); + // construct NVTE tensors + TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV; + if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { + // BF16 or FP16 + te_Q = MakeNvteTensor(Q); + te_KV = MakeNvteTensor(KV); + te_O = MakeNvteTensor(O); + te_dO = MakeNvteTensor(dO); + te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); + te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); + te_dQ = MakeNvteTensor(dQ); + te_dKV = MakeNvteTensor(dKV); + } else { + NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); + } + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // convert auxiliary tensors from forward into NVTETensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + nvte_aux_tensor_pack.size = 2; + auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); + auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); + output_s->data.shape = + std::vector({static_cast(b), static_cast(h), + static_cast(max_seqlen_q), static_cast(max_seqlen_kv)}); + output_s->data.dptr = const_cast(softmax_aux.data()); + fwd_rng_state->data.shape = std::vector({2}); + fwd_rng_state->data.dptr = const_cast(rng_state.data()); + + // create cu_seqlens tensorwrappers + TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; + te_cu_seqlens_q = + MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); + te_cu_seqlens_kv = + MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); + + // create workspace + TensorWrapper workspace; + + auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), + te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), + te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, + max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, + bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); + + // allocate memory for workspace + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + + // execute kernel + nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), + te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), + te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + dummy_seq_offsets.data(), 
dummy_seq_offsets.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, + max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, + bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); + + // destroy tensor wrappers + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); } void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V, @@ -962,83 +949,82 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p const std::string &qkv_layout, const std::string &bias_type, const std::string &attn_mask_type, const int64_t qkv_type, int64_t rng_elts_per_thread) { - if (is_training && !softmax_aux) { - NVTE_ERROR("softmax_aux must be provided when training. \n"); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor(Q); - te_K = MakeNvteTensor(K); - te_V = MakeNvteTensor(V); - te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); - te_O = MakeNvteTensor(O); - } else { // TODO: support fp8 - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - if ((bias_type != "no_bias") && Bias) { - auto bias_shape = Bias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); - } - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // extract random number generator seed and offset - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place()); - auto gen_cuda = dev_ctx->GetGenerator(); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast(rng_state.data())); - - auto te_rng_state = MakeNvteTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), - te_rng_state.data(), max_seqlen_q, max_seqlen_kv, - is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, workspace.data(), Q.stream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - auto *output_s = - reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - output_s->data.dptr = GetOptionalDataPtr(softmax_aux); - - // execute the kernel - 
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), - te_rng_state.data(), max_seqlen_q, max_seqlen_kv, - is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, workspace.data(), Q.stream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + if (is_training && !softmax_aux) { + NVTE_ERROR("softmax_aux must be provided when training. \n"); + } + + auto qkv_dtype = Int2NvteDType(qkv_type); + // construct NVTE tensors + TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; + if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { + // BF16 or FP16 + te_Q = MakeNvteTensor(Q); + te_K = MakeNvteTensor(K); + te_V = MakeNvteTensor(V); + te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); + te_O = MakeNvteTensor(O); + } else { // TODO: support fp8 + NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); + } + if ((bias_type != "no_bias") && Bias) { + auto bias_shape = Bias->shape(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); + } + te_cu_seqlens_q = + MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); + te_cu_seqlens_kv = + MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // extract random number generator seed and offset + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place()); + auto gen_cuda = dev_ctx->GetGenerator(); + auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); + set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast(rng_state.data())); + + auto te_rng_state = MakeNvteTensor(rng_state); + + // create auxiliary output tensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + // create workspace + TensorWrapper workspace; + + auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), + max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), + Q.stream()); + + // allocate memory for workspace and auxiliary output tensors + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); + + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + + auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); + output_s->data.dptr = GetOptionalDataPtr(softmax_aux); + + // execute the kernel + nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), + te_O.data(), 
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), + max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), + Q.stream()); + + // destroy tensor wrappers, but not allocated memory + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); } void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V, @@ -1054,236 +1040,234 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p float attn_scale, float p_dropout, const std::string &qkv_layout, const std::string &bias_type, const std::string &attn_mask_type, int64_t qkv_type) { - TensorWrapper te_dBias; - if (bias_type != "no_bias" && dBias) { - auto bias_shape = dBias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor(Q); - te_K = MakeNvteTensor(K); - te_V = MakeNvteTensor(V); - te_O = MakeNvteTensor(O); - te_dO = MakeNvteTensor(dO); - te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dQ = MakeNvteTensor(dQ); - te_dK = MakeNvteTensor(dK); - te_dV = MakeNvteTensor(dV); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - nvte_aux_tensor_pack.size = 2; - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); - output_s->data.shape = std::vector({static_cast(b), static_cast(h), - static_cast(max_seqlen_q), - static_cast(max_seqlen_kv)}); - output_s->data.dptr = const_cast(softmax_aux.data()); - fwd_rng_state->data.shape = std::vector({2}); - fwd_rng_state->data.dptr = const_cast(rng_state.data()); - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), - te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), - te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), - max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout_enum, 
bias_type_enum, attn_mask_type_enum, workspace.data(), - Q.stream()); - - // allocate memory for workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), - te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), - te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), - max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), - Q.stream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + TensorWrapper te_dBias; + if (bias_type != "no_bias" && dBias) { + auto bias_shape = dBias->shape(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); + } + + auto qkv_dtype = Int2NvteDType(qkv_type); + // construct NVTE tensors + TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; + if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { + // BF16 or FP16 + te_Q = MakeNvteTensor(Q); + te_K = MakeNvteTensor(K); + te_V = MakeNvteTensor(V); + te_O = MakeNvteTensor(O); + te_dO = MakeNvteTensor(dO); + te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); + te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); + te_dQ = MakeNvteTensor(dQ); + te_dK = MakeNvteTensor(dK); + te_dV = MakeNvteTensor(dV); + } else { + NVTE_ERROR("Fused attention only supports BF16/FP16 data types. 
\n"); + } + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // convert auxiliary tensors from forward into NVTETensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + nvte_aux_tensor_pack.size = 2; + auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); + auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); + output_s->data.shape = + std::vector({static_cast(b), static_cast(h), + static_cast(max_seqlen_q), static_cast(max_seqlen_kv)}); + output_s->data.dptr = const_cast(softmax_aux.data()); + fwd_rng_state->data.shape = std::vector({2}); + fwd_rng_state->data.dptr = const_cast(rng_state.data()); + + // create cu_seqlens tensorwrappers + TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; + te_cu_seqlens_q = + MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); + te_cu_seqlens_kv = + MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); + + // create workspace + TensorWrapper workspace; + + auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), + te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), + te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), + dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), + Q.stream()); + + // allocate memory for workspace + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + + // execute kernel + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), + te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), + te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), + dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), + Q.stream()); + + // destroy tensor wrappers + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); } std::vector te_scaled_softmax_forward(const paddle::Tensor &input, float scale_factor) { - NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK((input.dtype() == paddle::DataType::FLOAT16) || - (input.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); + NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor"); + NVTE_CHECK( + (input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16), + "Only fp16 and bf16 are supported"); - const int batches = input.shape()[0]; - const int attn_heads = input.shape()[1]; - const int query_seq_len = input.shape()[2]; - const int key_seq_len = input.shape()[3]; + const int batches = input.shape()[0]; + const int attn_heads = input.shape()[1]; + const int query_seq_len = input.shape()[2]; + const int key_seq_len = input.shape()[3]; - 
NVTE_CHECK(key_seq_len <= 4096); - NVTE_CHECK(query_seq_len > 1); + NVTE_CHECK(key_seq_len <= 4096); + NVTE_CHECK(query_seq_len > 1); - // Output - auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); + // Output + auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); - auto input_cu = MakeNvteTensor(input); - auto softmax_results_cu = MakeNvteTensor(softmax_results); + auto input_cu = MakeNvteTensor(input); + auto softmax_results_cu = MakeNvteTensor(softmax_results); - nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor, - input.stream()); + nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor, + input.stream()); - return {softmax_results}; + return {softmax_results}; } void te_scaled_softmax_backward(paddle::Tensor &output_grads, // NOLINT const paddle::Tensor &softmax_results, float scale_factor) { - NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor"); - - NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || - (output_grads.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || - (softmax_results.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - auto output_grads_cu = MakeNvteTensor(output_grads); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - // Produce gradients in place. - nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), - output_grads_cu.data(), scale_factor, softmax_results.stream()); + NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor"); + NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor"); + + NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || + (output_grads.dtype() == paddle::DataType::BFLOAT16), + "Only fp16 and bf16 are supported"); + NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || + (softmax_results.dtype() == paddle::DataType::BFLOAT16), + "Only fp16 and bf16 are supported"); + + auto output_grads_cu = MakeNvteTensor(output_grads); + auto softmax_results_cu = MakeNvteTensor(softmax_results); + + // Produce gradients in place. 
+ nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), + output_grads_cu.data(), scale_factor, softmax_results.stream()); } std::vector te_scaled_masked_softmax_forward(const paddle::Tensor &input, const paddle::Tensor &mask, float scale_factor) { - NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK(mask.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK((input.dtype() == paddle::DataType::FLOAT16) || - (input.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - const int batches = input.shape()[0]; - const int pad_batches = mask.shape()[0]; - const int attn_heads = input.shape()[1]; - const int query_seq_len = input.shape()[2]; - const int key_seq_len = input.shape()[3]; - - NVTE_CHECK(key_seq_len <= 4096); - NVTE_CHECK(query_seq_len > 1); - NVTE_CHECK(pad_batches == 1 || pad_batches == batches); - NVTE_CHECK(mask.shape()[1] == 1); - NVTE_CHECK(mask.shape()[2] == query_seq_len); - NVTE_CHECK(mask.shape()[3] == key_seq_len); - - // Output - auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto mask_cu = MakeNvteTensor(mask); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(), softmax_results_cu.data(), - scale_factor, input.stream()); - - return {softmax_results}; + NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor"); + NVTE_CHECK(mask.shape().size() == 4, "expected 4D tensor"); + NVTE_CHECK( + (input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16), + "Only fp16 and bf16 are supported"); + + const int batches = input.shape()[0]; + const int pad_batches = mask.shape()[0]; + const int attn_heads = input.shape()[1]; + const int query_seq_len = input.shape()[2]; + const int key_seq_len = input.shape()[3]; + + NVTE_CHECK(key_seq_len <= 4096); + NVTE_CHECK(query_seq_len > 1); + NVTE_CHECK(pad_batches == 1 || pad_batches == batches); + NVTE_CHECK(mask.shape()[1] == 1); + NVTE_CHECK(mask.shape()[2] == query_seq_len); + NVTE_CHECK(mask.shape()[3] == key_seq_len); + + // Output + auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); + + auto input_cu = MakeNvteTensor(input); + auto mask_cu = MakeNvteTensor(mask); + auto softmax_results_cu = MakeNvteTensor(softmax_results); + + nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(), softmax_results_cu.data(), + scale_factor, input.stream()); + + return {softmax_results}; } void te_scaled_masked_softmax_backward(paddle::Tensor &output_grads, // NOLINT const paddle::Tensor &softmax_results, float scale_factor) { - NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor"); - - NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || - (output_grads.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || - (softmax_results.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - auto output_grads_cu = MakeNvteTensor(output_grads); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - // Produce gradients in place. 
- nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), - output_grads_cu.data(), scale_factor, softmax_results.stream()); + NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor"); + NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor"); + + NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || + (output_grads.dtype() == paddle::DataType::BFLOAT16), + "Only fp16 and bf16 are supported"); + NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || + (softmax_results.dtype() == paddle::DataType::BFLOAT16), + "Only fp16 and bf16 are supported"); + + auto output_grads_cu = MakeNvteTensor(output_grads); + auto softmax_results_cu = MakeNvteTensor(softmax_results); + + // Produce gradients in place. + nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), + output_grads_cu.data(), scale_factor, softmax_results.stream()); } std::vector te_scaled_upper_triang_masked_softmax_forward( const paddle::Tensor &input, float scale_factor) { - NVTE_CHECK(input.shape().size() == 3, "expected 3D tensor"); - NVTE_CHECK((input.dtype() == paddle::DataType::FLOAT16) || - (input.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); + NVTE_CHECK(input.shape().size() == 3, "expected 3D tensor"); + NVTE_CHECK( + (input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16), + "Only fp16 and bf16 are supported"); - const int attn_batches = input.shape()[0]; - const int seq_len = input.shape()[1]; - NVTE_CHECK(seq_len <= 2048); + const int attn_batches = input.shape()[0]; + const int seq_len = input.shape()[1]; + NVTE_CHECK(seq_len <= 2048); - // Output - auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); + // Output + auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); - auto input_cu = MakeNvteTensor(input); - auto softmax_results_cu = MakeNvteTensor(softmax_results); + auto input_cu = MakeNvteTensor(input); + auto softmax_results_cu = MakeNvteTensor(softmax_results); - nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(), - scale_factor, input.stream()); + nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(), + scale_factor, input.stream()); - return {softmax_results}; + return {softmax_results}; } void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads, // NOLINT const paddle::Tensor &softmax_results, float scale_factor) { - NVTE_CHECK(output_grads.shape().size() == 3, "expected 3D tensor"); - NVTE_CHECK(softmax_results.shape().size() == 3, "expected 3D tensor"); - - NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || - (output_grads.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || - (softmax_results.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK(output_grads.shape()[1] == output_grads.shape()[2]); - - auto output_grads_cu = MakeNvteTensor(output_grads); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - // Produce gradients in place. 
- nvte_scaled_upper_triang_masked_softmax_backward( - output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor, - softmax_results.stream()); + NVTE_CHECK(output_grads.shape().size() == 3, "expected 3D tensor"); + NVTE_CHECK(softmax_results.shape().size() == 3, "expected 3D tensor"); + + NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || + (output_grads.dtype() == paddle::DataType::BFLOAT16), + "Only fp16 and bf16 are supported"); + NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || + (softmax_results.dtype() == paddle::DataType::BFLOAT16), + "Only fp16 and bf16 are supported"); + NVTE_CHECK(output_grads.shape()[1] == output_grads.shape()[2]); + + auto output_grads_cu = MakeNvteTensor(output_grads); + auto softmax_results_cu = MakeNvteTensor(softmax_results); + + // Produce gradients in place. + nvte_scaled_upper_triang_masked_softmax_backward( + output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor, + softmax_results.stream()); } constexpr int BLOCK_SIZE = 512; @@ -1291,115 +1275,103 @@ constexpr int BLOCK_SIZE = 512; void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT paddle::Tensor &scale, // NOLINT paddle::Tensor &scale_inv, // NOLINT - const paddle::Tensor &non_weight_mask, - int64_t fp8_dtype, - float margin, - const std::string &amax_compute) { + const paddle::Tensor &non_weight_mask, int64_t fp8_dtype, + float margin, const std::string &amax_compute) { auto amax_history_ = MakeNvteTensor(amax_history); auto scale_ = MakeNvteTensor(scale); auto scale_inv_ = MakeNvteTensor(scale_inv); const auto non_weight_mask_ = MakeNvteTensor(non_weight_mask); nvte_delayed_scaling_recipe_amax_and_scale_update( - amax_history_.data(), - scale_.data(), - scale_inv_.data(), - non_weight_mask_.data(), - amax_history_.data(), - scale_.data(), - scale_inv_.data(), - amax_compute.c_str(), - static_cast(fp8_dtype), - margin, - amax_history.stream()); + amax_history_.data(), scale_.data(), scale_inv_.data(), non_weight_mask_.data(), + amax_history_.data(), scale_.data(), scale_inv_.data(), amax_compute.c_str(), + static_cast(fp8_dtype), margin, amax_history.stream()); } void update_latest_amax_history_inplace(paddle::Tensor &history, // NOLINT const paddle::Tensor &amax) { - // Copy amax to history[0] - NVTE_CHECK_CUDA(cudaMemcpyAsync(history.data(), amax.data(), - amax.numel() * SizeOf(amax.dtype()), cudaMemcpyDeviceToDevice, - amax.stream())); + // Copy amax to history[0] + NVTE_CHECK_CUDA(cudaMemcpyAsync(history.data(), amax.data(), amax.numel() * SizeOf(amax.dtype()), + cudaMemcpyDeviceToDevice, amax.stream())); } __global__ __launch_bounds__(BLOCK_SIZE) void mask_to_actual_seqlens_kernel( const bool *mask, int32_t *q_actual_seqlen, int32_t *kv_actual_seqlen, int q_seqlen, int kv_seqlen, bool need_kv) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage q_smem; - __shared__ typename BlockReduce::TempStorage kv_smem; - unsigned int tid = threadIdx.x; - unsigned int batch_offset = blockIdx.x * q_seqlen * kv_seqlen; - - // load mask, convert to 1/0, do accumulation - int q = 0, kv = 0; - for (unsigned int q_idx = tid * kv_seqlen; q_idx < q_seqlen * kv_seqlen; - q_idx += BLOCK_SIZE * kv_seqlen) { - q += (mask[q_idx + batch_offset] ? 
0 : 1); + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage q_smem; + __shared__ typename BlockReduce::TempStorage kv_smem; + unsigned int tid = threadIdx.x; + unsigned int batch_offset = blockIdx.x * q_seqlen * kv_seqlen; + + // load mask, convert to 1/0, do accumulation + int q = 0, kv = 0; + for (unsigned int q_idx = tid * kv_seqlen; q_idx < q_seqlen * kv_seqlen; + q_idx += BLOCK_SIZE * kv_seqlen) { + q += (mask[q_idx + batch_offset] ? 0 : 1); + } + + if (need_kv) { + for (unsigned int kv_idx = tid; kv_idx < kv_seqlen; kv_idx += BLOCK_SIZE) { + kv += (mask[kv_idx + batch_offset] ? 0 : 1); } + } + __syncthreads(); + + // compute cub::BlockReduce + int q_sum, kv_sum; + q_sum = BlockReduce(q_smem).Sum(q); + if (need_kv) kv_sum = BlockReduce(kv_smem).Sum(kv); + // write result for this block to global mem + if (tid == 0) { + q_actual_seqlen[blockIdx.x + 1] = q_sum; if (need_kv) { - for (unsigned int kv_idx = tid; kv_idx < kv_seqlen; kv_idx += BLOCK_SIZE) { - kv += (mask[kv_idx + batch_offset] ? 0 : 1); - } - } - __syncthreads(); - - // compute cub::BlockReduce - int q_sum, kv_sum; - q_sum = BlockReduce(q_smem).Sum(q); - if (need_kv) kv_sum = BlockReduce(kv_smem).Sum(kv); - - // write result for this block to global mem - if (tid == 0) { - q_actual_seqlen[blockIdx.x + 1] = q_sum; - if (need_kv) { - kv_actual_seqlen[blockIdx.x + 1] = kv_sum; - } + kv_actual_seqlen[blockIdx.x + 1] = kv_sum; } + } } __global__ __launch_bounds__(BLOCK_SIZE) void block_prefix_sum_inplace(int32_t *x, int n) { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage smem; - // +1 to ignore the first element - int i = blockIdx.x * blockDim.x + threadIdx.x + 1; - - // load data - int32_t thread_data[1]; - thread_data[0] = i < n ? x[i] : 0; - __syncthreads(); - - // CUB block prefix sum - BlockScan(smem).InclusiveSum(thread_data, thread_data); - __syncthreads(); - - // write result - if (i < n) { - x[i] = thread_data[0]; - } + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage smem; + // +1 to ignore the first element + int i = blockIdx.x * blockDim.x + threadIdx.x + 1; + + // load data + int32_t thread_data[1]; + thread_data[0] = i < n ? 
x[i] : 0; + __syncthreads(); + + // CUB block prefix sum + BlockScan(smem).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + // write result + if (i < n) { + x[i] = thread_data[0]; + } } void mask_to_cu_seqlens(const paddle::Tensor &mask, paddle::Tensor &q_cu_seqlen, // NOLINT paddle::optional &kv_cu_seqlen, // NOLINT int q_seqlen, int kv_seqlen, bool need_kv) { - if (need_kv) { - NVTE_CHECK(GetOptionalDataPtr(kv_cu_seqlen) != nullptr, - "kv_cu_seqlen must be provided when need_kv is true"); - } - mask_to_actual_seqlens_kernel<<>>( - mask.data(), q_cu_seqlen.data(), - reinterpret_cast(GetOptionalDataPtr(kv_cu_seqlen)), q_seqlen, kv_seqlen, - need_kv); - // q_cu_seqlen shape: [bs+1], assume bs is not too large (<=512), so we can use a single block - // to do prefix sum - NVTE_CHECK(q_cu_seqlen.numel() - 1 <= BLOCK_SIZE, "batch size too large, kernel may fail"); - block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>(q_cu_seqlen.data(), - q_cu_seqlen.numel()); - if (need_kv) { - block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>( - reinterpret_cast(GetOptionalDataPtr(kv_cu_seqlen)), kv_cu_seqlen->numel()); - } + if (need_kv) { + NVTE_CHECK(GetOptionalDataPtr(kv_cu_seqlen) != nullptr, + "kv_cu_seqlen must be provided when need_kv is true"); + } + mask_to_actual_seqlens_kernel<<>>( + mask.data(), q_cu_seqlen.data(), + reinterpret_cast(GetOptionalDataPtr(kv_cu_seqlen)), q_seqlen, kv_seqlen, need_kv); + // q_cu_seqlen shape: [bs+1], assume bs is not too large (<=512), so we can use a single block + // to do prefix sum + NVTE_CHECK(q_cu_seqlen.numel() - 1 <= BLOCK_SIZE, "batch size too large, kernel may fail"); + block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>(q_cu_seqlen.data(), + q_cu_seqlen.numel()); + if (need_kv) { + block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>( + reinterpret_cast(GetOptionalDataPtr(kv_cu_seqlen)), kv_cu_seqlen->numel()); + } } } // namespace paddle_ext diff --git a/transformer_engine/paddle/csrc/extensions.cu b/transformer_engine/paddle/csrc/extensions.cu index cb0183a1742d2aa38d1e3f6726a67a3c06e701b2..128b7e28564c4e288b605fbc046a467627635f4d 100644 --- a/transformer_engine/paddle/csrc/extensions.cu +++ b/transformer_engine/paddle/csrc/extensions.cu @@ -12,52 +12,52 @@ namespace paddle_ext { size_t get_cublasLt_version() { return cublasLtGetVersion(); } PYBIND11_MODULE(transformer_engine_paddle, m) { - // Misc - m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); - m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); - m.def("get_nvte_qkv_layout", &get_nvte_qkv_layout, "Get qkv layout enum by the string"); - // Data structures - py::enum_(m, "DType", py::module_local()) - .value("kByte", DType::kByte) - .value("kInt32", DType::kInt32) - .value("kFloat32", DType::kFloat32) - .value("kFloat16", DType::kFloat16) - .value("kBFloat16", DType::kBFloat16) - .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); - - py::enum_(m, "NVTE_Bias_Type") - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - - py::enum_(m, "NVTE_Mask_Type") - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); - - py::enum_(m, "NVTE_QKV_Layout") - .value("NVTE_SB3HD", 
NVTE_QKV_Layout::NVTE_SB3HD) - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); - - py::enum_(m, "NVTE_Fused_Attn_Backend", py::module_local()) - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); + // Misc + m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); + m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); + m.def("get_nvte_qkv_layout", &get_nvte_qkv_layout, "Get qkv layout enum by the string"); + // Data structures + py::enum_(m, "DType", py::module_local()) + .value("kByte", DType::kByte) + .value("kInt32", DType::kInt32) + .value("kFloat32", DType::kFloat32) + .value("kFloat16", DType::kFloat16) + .value("kBFloat16", DType::kBFloat16) + .value("kFloat8E4M3", DType::kFloat8E4M3) + .value("kFloat8E5M2", DType::kFloat8E5M2); + + py::enum_(m, "NVTE_Bias_Type") + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + + py::enum_(m, "NVTE_Mask_Type") + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); + + py::enum_(m, "NVTE_QKV_Layout") + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); + + py::enum_(m, "NVTE_Fused_Attn_Backend", py::module_local()) + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); 
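For reference, a minimal NumPy sketch of what the mask_to_actual_seqlens_kernel / block_prefix_sum_inplace pair above computes, assuming a boolean [b, 1, s_q, s_kv] padding mask in which True marks masked-out positions; the helper name below is illustrative and not part of this patch.

import numpy as np

def mask_to_cu_seqlens_reference(mask: np.ndarray):
    """mask: [b, 1, s_q, s_kv] bool, True = padded / masked out."""
    # Actual query length per batch: query rows whose first key column is unmasked.
    q_seqlens = (~mask[:, 0, :, 0]).sum(axis=-1)
    # Actual key/value length per batch: key columns unmasked in the first query row.
    kv_seqlens = (~mask[:, 0, 0, :]).sum(axis=-1)
    # Inclusive prefix sum with a leading zero gives the cumulative sequence
    # lengths ([b + 1] int32 tensors) consumed by the fused-attention kernels.
    cu_seqlens_q = np.concatenate(([0], np.cumsum(q_seqlens))).astype(np.int32)
    cu_seqlens_kv = np.concatenate(([0], np.cumsum(kv_seqlens))).astype(np.int32)
    return cu_seqlens_q, cu_seqlens_kv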
} } // namespace paddle_ext } // namespace transformer_engine diff --git a/transformer_engine/paddle/distributed.py b/transformer_engine/paddle/distributed.py index 24455e1146f2253956a1110bb7b15f0aca329838..e6be9cb1c76f33e72209af3e99d937f2a1e40176 100644 --- a/transformer_engine/paddle/distributed.py +++ b/transformer_engine/paddle/distributed.py @@ -17,26 +17,25 @@ from paddle.distributed.fleet.layers.mpu import mp_ops from .constants import dist_group_type _weight_split_axis = { - 'transformer_engine': { - 'row': 1, - 'column': 0 - }, - 'paddle': { - 'row': 0, - 'column': 1 - } + "transformer_engine": {"row": 1, "column": 0}, + "paddle": {"row": 0, "column": 1}, } -def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None], - enable_tp: bool = True) -> Tuple[Union[dist_group_type, None], int]: +def get_tp_group_and_world_size( + tp_group: Union[dist_group_type, None], enable_tp: bool = True +) -> Tuple[Union[dist_group_type, None], int]: """Get TP group and world size using Fleet API""" if not (paddle.distributed.is_initialized() and enable_tp): return None, 1 - model_parallel_group = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group() - if tp_group is None else tp_group) - world_size = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size() - if tp_group is None else tp_group.nranks) + model_parallel_group = ( + tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group() if tp_group is None else tp_group + ) + world_size = ( + tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size() + if tp_group is None + else tp_group.nranks + ) """ When using TP, the NCCL communication needs to be scheduled before the GEMM for a guaranteed overlap. From the host side @@ -47,8 +46,10 @@ def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None], """ num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) if num_cuda_work_queues != 1: - warnings.warn("To guarantee overlapping TP and SP collectives with the backward" - "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1") + warnings.warn( + "To guarantee overlapping TP and SP collectives with the backward" + "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1" + ) return model_parallel_group, world_size @@ -73,8 +74,9 @@ def set_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, axis: int) -> tensor.split_axis = axis -def set_weight_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, - parallel_mode: Optional[str], backend: str) -> None: +def set_weight_tensor_dist_attr( + tensor: paddle.Tensor, is_parallel: bool, parallel_mode: Optional[str], backend: str +) -> None: """Set distributed attributes for the weight tensor""" if not is_parallel or parallel_mode is None: return @@ -149,17 +151,15 @@ def reduce_scatter( parallelism = tp_group.nranks output_shape = input_.shape - assert ( - input_.shape[0] % parallelism == 0 - ), f"Input sequence length {input_.shape[0]} can't be divided " \ + assert input_.shape[0] % parallelism == 0, ( + f"Input sequence length {input_.shape[0]} can't be divided " f"exactly by sequence parallelism {parallelism}" + ) output_shape[0] = output_shape[0] // parallelism output = paddle.empty(shape=output_shape, dtype=input_.dtype) - wait_handle = paddle.distributed.stream.reduce_scatter(output, - input_, - op=paddle.distributed.ReduceOp.SUM, - group=tp_group, - sync_op=sync_op) + wait_handle = paddle.distributed.stream.reduce_scatter( + output, input_, op=paddle.distributed.ReduceOp.SUM, group=tp_group, sync_op=sync_op + ) if sync_op: return output, 
None return output, wait_handle diff --git a/transformer_engine/paddle/fp8.py b/transformer_engine/paddle/fp8.py index 553355551d595460343fffba3b6d41df25de7370..856ce505238fa2f96e5130c37c4395e9b6646ece 100644 --- a/transformer_engine/paddle/fp8.py +++ b/transformer_engine/paddle/fp8.py @@ -15,7 +15,7 @@ from transformer_engine.common.recipe import DelayedScaling, Format from .constants import dist_group_type from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer -__all__ = ['fp8_autocast'] +__all__ = ["fp8_autocast"] # FP8 support _is_fp8_available = None @@ -27,9 +27,9 @@ def _check_fp8_support() -> Tuple[bool, str]: # Check GPU arch arch = paddle.device.cuda.get_device_capability() - if arch >= (9, 0): # hopper and above + if arch >= (9, 0): # hopper and above return True, "" - if arch < (8, 9): # pre-ada + if arch < (8, 9): # pre-ada return False, "Device compute capability 8.9 or higher required for FP8 execution." # Special handling for Ada @@ -124,8 +124,13 @@ class FP8State: fp8_group: Optional[dist_group_type], ) -> None: """Called when entering 'fp8_autocast'""" - self.saved_states = (self._fp8_enabled, self._fp8_calibration, self._fp8_recipe, - self._fp8_distributed_group, self._is_first_fp8_module) + self.saved_states = ( + self._fp8_enabled, + self._fp8_calibration, + self._fp8_recipe, + self._fp8_distributed_group, + self._is_first_fp8_module, + ) self._fp8_enabled = enabled self._fp8_calibration = calibrating @@ -140,8 +145,13 @@ class FP8State: def exit(self): """Called when exiting 'fp8_autocast'""" # Restore saved states - (self._fp8_enabled, self._fp8_calibration, self._fp8_recipe, self._fp8_distributed_group, - self._is_first_fp8_module) = self.saved_states + ( + self._fp8_enabled, + self._fp8_calibration, + self._fp8_recipe, + self._fp8_distributed_group, + self._is_first_fp8_module, + ) = self.saved_states self._fp8_autocast_depth -= 1 @@ -214,8 +224,9 @@ def fp8_autocast( def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: """Get fp8 data type according to recipe and tensor""" - if fp8_recipe.fp8_format == Format.E4M3 or (fp8_recipe.fp8_format == Format.HYBRID - and fprop_tensor): + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): return tex.DType.kFloat8E4M3 return tex.DType.kFloat8E5M2 @@ -241,14 +252,17 @@ def amax_and_scale_update( non_weight_mask=non_weight_mask, fp8_dtype=int(get_fp8_te_dtype(fp8_meta["recipe"], fwd_update)), margin=float(fp8_meta["recipe"].margin), - amax_compute=amax_compute) + amax_compute=amax_compute, + ) else: - raise ValueError("We only support the fp8 recipe with 'max' or 'most_recent' " - "amax_compute_algo and default scaling_factor_compute_algo at this " - "moment.") + raise ValueError( + "We only support the fp8 recipe with 'max' or 'most_recent' " + "amax_compute_algo and default scaling_factor_compute_algo at this " + "moment." 
+ ) -class FP8TensorMeta(): +class FP8TensorMeta: """Holds FP8 scaling and amax history for FP8 layers""" def __init__(self, is_forward: bool): @@ -281,20 +295,22 @@ class FP8TensorMeta(): self.amax_history = self.amax_history[:amax_history_len] elif amax_history_len > curr_len: extra_rows = amax_history_len - curr_len - self.amax_history = paddle.concat([ - self.amax_history, - paddle.zeros((extra_rows, num_fp8_tensors), dtype='float32') - ], - axis=0) + self.amax_history = paddle.concat( + [ + self.amax_history, + paddle.zeros((extra_rows, num_fp8_tensors), dtype="float32"), + ], + axis=0, + ) return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd - num_fp8_tensors = (num_gemms * 3 if self.is_forward else num_gemms * 2) + num_fp8_tensors = num_gemms * 3 if self.is_forward else num_gemms * 2 - self.scale = paddle.ones(num_fp8_tensors, dtype='float32') - self.scale_inv = paddle.ones(num_fp8_tensors, dtype='float32') - self.amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype='float32') + self.scale = paddle.ones(num_fp8_tensors, dtype="float32") + self.scale_inv = paddle.ones(num_fp8_tensors, dtype="float32") + self.amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype="float32") self.non_weight_mask = self.get_non_weight_mask(num_gemms=num_gemms) self.is_initialized = True @@ -303,16 +319,16 @@ class FP8TensorMeta(): """Convert FP8 meta tensors to numpy.""" assert self.is_initialized, "FP8TensorMeta is not initialized yet." return { - 'scale': self.scale.numpy(), - 'scale_inv': self.scale_inv.numpy(), - 'amax_history': self.amax_history.numpy(), + "scale": self.scale.numpy(), + "scale_inv": self.scale_inv.numpy(), + "amax_history": self.amax_history.numpy(), } def from_numpy(self, data: Dict[str, np.array]): """Set FP8 meta tensors from numpy""" - self.scale = paddle.to_tensor(data['scale']) - self.scale_inv = paddle.to_tensor(data['scale_inv']) - self.amax_history = paddle.to_tensor(data['amax_history']) + self.scale = paddle.to_tensor(data["scale"]) + self.scale_inv = paddle.to_tensor(data["scale_inv"]) + self.amax_history = paddle.to_tensor(data["amax_history"]) num_fp8_tensors = self.scale.shape[0] num_gemms = num_fp8_tensors // 3 if self.is_forward else num_fp8_tensors // 2 diff --git a/transformer_engine/paddle/fp8_buffer.py b/transformer_engine/paddle/fp8_buffer.py index bcf4be95acf6c4b2467590c55c47cebc96b05f53..0ef88a85b4281bca3f9b48f25bd01760138b874d 100644 --- a/transformer_engine/paddle/fp8_buffer.py +++ b/transformer_engine/paddle/fp8_buffer.py @@ -49,7 +49,7 @@ class FP8MetaBufferBase(ABC): def _execute_deletion(self) -> None: """Delete the key from global amax buffer.""" - if (self._buffer_delete_key is not None and self._buffer_delete_key in self._data): + if self._buffer_delete_key is not None and self._buffer_delete_key in self._data: del self._data[self._buffer_delete_key] def _wait_handle_and_split( @@ -137,11 +137,12 @@ class FP8MetaBufferBase(ABC): fp8_meta[buffer_position_key] = len(self._data[buffer_key]) - 1 # Catch incorrect fp8_autocast usage. - assert fp8_meta[buffer_position_key] == len(self._data[buffer_key]) - 1, \ - "Same module is being invoked more than once inside an `fp8_autocast` " \ - "region when using FP8 with amax reduction. This behavior is currently " \ - "unsupported. 
For more details and correct usage, please see " \ + assert fp8_meta[buffer_position_key] == len(self._data[buffer_key]) - 1, ( + "Same module is being invoked more than once inside an `fp8_autocast` " + "region when using FP8 with amax reduction. This behavior is currently " + "unsupported. For more details and correct usage, please see " "https://github.com/NVIDIA/TransformerEngine/pull/93." + ) def copy_amax_from_buffer(self, fp8_meta: Dict[str, Any]) -> None: """Populate current amax with the correct location from buffer.""" @@ -156,7 +157,8 @@ class FP8MetaBufferBase(ABC): # Copy amax to amax_history[0] tex.update_latest_amax_history_inplace( _history=fp8_meta[fp8_meta_tensor_key].amax_history, - amax=self._data[amax_buffer_key][fp8_meta[buffer_position_key]]) + amax=self._data[amax_buffer_key][fp8_meta[buffer_position_key]], + ) def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None: """Delete this amax key from global buffer during autocast end.""" @@ -171,7 +173,7 @@ class FP8MetaBufferBase(ABC): def wait(self) -> None: """Wait for reduced amax to be available in buffer.""" if self._amax_reduce_wait_func is not None: - self._amax_reduce_wait_func() # pylint: disable=not-callable + self._amax_reduce_wait_func() # pylint: disable=not-callable self._amax_reduce_wait_func = None def to_numpy(self) -> Dict[str, List[np.array]]: @@ -224,7 +226,7 @@ class FP8MetaFwdBuffer(FP8MetaBufferBase): Called at FP8 autocast end. Performs AMAX reduction and delete unused buffer entries. """ - if hasattr(self, '_amax_global_reduce_func') and callable(self._amax_global_reduce_func): + if hasattr(self, "_amax_global_reduce_func") and callable(self._amax_global_reduce_func): self._amax_reduce_wait_func = self._amax_global_reduce_func() self._execute_deletion() @@ -270,7 +272,7 @@ class FP8RecomputeBuffer: @staticmethod def get_buffer_position_key(): """Returns the key (in fp8_meta) for recompute buffer position""" - return 'recompute_buffer_pos' + return "recompute_buffer_pos" def stash_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None: """Stash the scaling factors and amaxes for recompute""" @@ -308,11 +310,13 @@ class FP8RecomputeBuffer: @staticmethod def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" - assert "updated_amax_history_fwd" in fp8_meta, "Recompute internal error." \ - " If you are not using recompute, please check if" \ - " the forward function is called from one of these functions: " \ - f"{RecomputeFunctionNames}. If so, consider change the function name " \ + assert "updated_amax_history_fwd" in fp8_meta, ( + "Recompute internal error." + " If you are not using recompute, please check if" + " the forward function is called from one of these functions: " + f"{RecomputeFunctionNames}. If so, consider change the function name " "or set NVTE_DISABLE_RECOMPUTE=1." 
+ ) fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"] fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"] fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"] diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py index 0340ccd21c519dc4986282935e8ced784f896b93..98e50b9e04790eafcd1b53e46dd485cba6a46dd4 100644 --- a/transformer_engine/paddle/layer/attention.py +++ b/transformer_engine/paddle/layer/attention.py @@ -10,6 +10,7 @@ from typing import Optional, Tuple, Union import paddle import paddle.nn.functional as F + try: from paddle.incubate.nn.functional import fused_rotary_position_embedding except ImportError: @@ -19,8 +20,14 @@ from transformer_engine import transformer_engine_paddle as tex from .layernorm_linear import LayerNormLinear from .linear import Linear from .softmax import FusedScaleMaskSoftmax -from ..constants import (AttnTypes, TE_DType, AttnBiasType, AttnMaskType, FusedAttnBackend, - dist_group_type) +from ..constants import ( + AttnTypes, + TE_DType, + AttnBiasType, + AttnMaskType, + FusedAttnBackend, + dist_group_type, +) from ..cpp_extensions import ( fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked, @@ -72,8 +79,9 @@ class RotaryPositionEmbedding(paddle.nn.Layer): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings - self.inv_freq = 1.0 / (10000**(paddle.cast(paddle.arange(0, dim, 2), dtype='float32') / - self.dim)) + self.inv_freq = 1.0 / ( + 10000 ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / self.dim) + ) self._set_cos_sin_cache(seq_len=max_position_embeddings) def _set_cos_sin_cache(self, seq_len): @@ -104,9 +112,9 @@ class RotaryPositionEmbedding(paddle.nn.Layer): def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return paddle.concat([-x2, x1], axis=-1) # shape is the same as x + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) # shape is the same as x def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None): @@ -114,13 +122,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None): if position_ids is None: # Note: Only for LlamaForCausalLMPipe model pretraining - cos = cos[:, :q.shape[1], :, :] # [bs, seq_len, 1, dim] - sin = sin[:, :q.shape[1], :, :] # [bs, seq_len, 1, dim] + cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] else: - cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] - sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] + sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -130,9 +138,22 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): """Function for FusedAttention with packed QKV input""" @staticmethod - def forward(ctx, qkv, cu_seqlens, attn_bias, max_seqlen, attn_scale, qkv_dtype, dropout_p, - set_zero, qkv_layout, attn_bias_type, attn_mask_type, is_training, - fused_attention_backend): + def forward( + ctx, + qkv, + cu_seqlens, + attn_bias, + max_seqlen, + 
attn_scale, + qkv_dtype, + dropout_p, + set_zero, + qkv_layout, + attn_bias_type, + attn_mask_type, + is_training, + fused_attention_backend, + ): """Forward function for FusedAttention with packed QKV input""" out, softmax_aux, rng_state = fused_attn_fwd_qkvpacked( qkv, @@ -167,11 +188,23 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): def backward(ctx, d_out): """Backward function for FusedAttention with packed QKV input""" qkv, out, cu_seqlens, rng_state, softmax_aux = ctx.saved_tensor() - dqkv, *rest = fused_attn_bwd_qkvpacked(qkv, cu_seqlens, rng_state, out, d_out, softmax_aux, - ctx.fused_attention_backend, ctx.max_seqlen, - ctx.qkv_dtype, ctx.attn_scale, ctx.dropout_p, - ctx.set_zero, ctx.qkv_layout, ctx.attn_bias_type, - ctx.attn_mask_type) + dqkv, *rest = fused_attn_bwd_qkvpacked( + qkv, + cu_seqlens, + rng_state, + out, + d_out, + softmax_aux, + ctx.fused_attention_backend, + ctx.max_seqlen, + ctx.qkv_dtype, + ctx.attn_scale, + ctx.dropout_p, + ctx.set_zero, + ctx.qkv_layout, + ctx.attn_bias_type, + ctx.attn_mask_type, + ) # if no_bias, return dqkv if ctx.attn_bias_type == "no_bias": @@ -184,14 +217,44 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer): """Function for FusedAttention with packed KV input""" @staticmethod - def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv, - attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type, - attn_mask_type, is_training, fused_attention_backend): + def forward( + ctx, + q, + kv, + cu_seqlens_q, + cu_seqlens_kv, + attn_bias, + max_seqlen_q, + max_seqlen_kv, + attn_scale, + qkv_dtype, + dropout_p, + set_zero, + qkv_layout, + attn_bias_type, + attn_mask_type, + is_training, + fused_attention_backend, + ): """Forward function for FusedAttention with packed KV input""" out, softmax_aux, rng_state = fused_attn_fwd_kvpacked( - q, kv, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv, qkv_dtype, - fused_attention_backend, attn_bias, attn_scale, dropout_p, set_zero, qkv_layout, - attn_bias_type, attn_mask_type) + q, + kv, + cu_seqlens_q, + cu_seqlens_kv, + is_training, + max_seqlen_q, + max_seqlen_kv, + qkv_dtype, + fused_attention_backend, + attn_bias, + attn_scale, + dropout_p, + set_zero, + qkv_layout, + attn_bias_type, + attn_mask_type, + ) ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux) ctx.max_seqlen_q = max_seqlen_q @@ -211,12 +274,26 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer): def backward(ctx, d_out): """Backward function for FusedAttention with packed KV input""" q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor() - dq, dkv, *rest = fused_attn_bwd_kvpacked(q, kv, cu_seqlens_q, cu_seqlens_kv, rng_state, out, - d_out, softmax_aux, ctx.fused_attention_backend, - ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype, - ctx.attn_scale, ctx.dropout_p, ctx.set_zero, - ctx.qkv_layout, ctx.attn_bias_type, - ctx.attn_mask_type) + dq, dkv, *rest = fused_attn_bwd_kvpacked( + q, + kv, + cu_seqlens_q, + cu_seqlens_kv, + rng_state, + out, + d_out, + softmax_aux, + ctx.fused_attention_backend, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ctx.qkv_dtype, + ctx.attn_scale, + ctx.dropout_p, + ctx.set_zero, + ctx.qkv_layout, + ctx.attn_bias_type, + ctx.attn_mask_type, + ) # if no_bias, return dq, dkv if ctx.attn_bias_type == "no_bias": @@ -229,15 +306,46 @@ class FusedAttnFunc(paddle.autograd.PyLayer): """Function for FusedAttention with separate Q, K, V tensors""" @staticmethod - def 
forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv, - attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type, - attn_mask_type, is_training, fused_attention_backend): + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + attn_bias, + max_seqlen_q, + max_seqlen_kv, + attn_scale, + qkv_dtype, + dropout_p, + set_zero, + qkv_layout, + attn_bias_type, + attn_mask_type, + is_training, + fused_attention_backend, + ): """Forward function for FusedAttention with separate Q, K, V tensors""" - out, softmax_aux, rng_state = fused_attn_fwd(q, k, v, cu_seqlens_q, cu_seqlens_kv, - is_training, max_seqlen_q, max_seqlen_kv, - qkv_dtype, fused_attention_backend, attn_bias, - attn_scale, dropout_p, set_zero, qkv_layout, - attn_bias_type, attn_mask_type) + out, softmax_aux, rng_state = fused_attn_fwd( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + is_training, + max_seqlen_q, + max_seqlen_kv, + qkv_dtype, + fused_attention_backend, + attn_bias, + attn_scale, + dropout_p, + set_zero, + qkv_layout, + attn_bias_type, + attn_mask_type, + ) ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux) ctx.max_seqlen_q = max_seqlen_q @@ -257,11 +365,27 @@ class FusedAttnFunc(paddle.autograd.PyLayer): def backward(ctx, d_out): """Backward function for FusedAttention with separate Q, K, V tensors""" q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor() - dq, dk, dv, *rest = fused_attn_bwd(q, k, v, cu_seqlens_q, cu_seqlens_kv, rng_state, out, - d_out, softmax_aux, ctx.fused_attention_backend, - ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype, - ctx.attn_scale, ctx.dropout_p, ctx.set_zero, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + dq, dk, dv, *rest = fused_attn_bwd( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + rng_state, + out, + d_out, + softmax_aux, + ctx.fused_attention_backend, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ctx.qkv_dtype, + ctx.attn_scale, + ctx.dropout_p, + ctx.set_zero, + ctx.qkv_layout, + ctx.attn_bias_type, + ctx.attn_mask_type, + ) # if no_bias, return dq, dk, dv if ctx.attn_bias_type == "no_bias": return (dq, dk, dv, None, None) @@ -306,15 +430,17 @@ class DotProductAttention(paddle.nn.Layer): backend to use for attention operation. 
""" - def __init__(self, - num_attention_heads: int, - kv_channels: int, - num_gqa_groups: Optional[int] = None, - attention_dropout: float = 0.1, - attn_mask_type: str = "causal", - attention_type: str = "self", - tp_size: int = 1, - backend: str = 'transformer_engine') -> None: + def __init__( + self, + num_attention_heads: int, + kv_channels: int, + num_gqa_groups: Optional[int] = None, + attention_dropout: float = 0.1, + attn_mask_type: str = "causal", + attention_type: str = "self", + tp_size: int = 1, + backend: str = "transformer_engine", + ) -> None: super().__init__() self.attn_mask_type = attn_mask_type @@ -324,7 +450,7 @@ class DotProductAttention(paddle.nn.Layer): self.hidden_size_per_attention_head = kv_channels self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) self.tp_size = tp_size - self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups) + self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups @@ -332,14 +458,14 @@ class DotProductAttention(paddle.nn.Layer): self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1"))) - if not self.use_fused_attention and backend == 'transformer_engine': + if not self.use_fused_attention and backend == "transformer_engine": warnings.warn("Fused attention is not enabled, falling back to Paddle backend") - self.backend = 'paddle' + self.backend = "paddle" - if self.backend != 'transformer_engine': - self.scale_mask_softmax = FusedScaleMaskSoftmax(attn_mask_type, - attention_mask_func, - backend=self.backend) + if self.backend != "transformer_engine": + self.scale_mask_softmax = FusedScaleMaskSoftmax( + attn_mask_type, attention_mask_func, backend=self.backend + ) def forward( self, @@ -380,35 +506,53 @@ class DotProductAttention(paddle.nn.Layer): backend = self.backend - assert (key_layer.shape == value_layer.shape), "Keys and values must have the same shape!" - assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" + assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!" + assert ( + key_layer.shape[-2] == self.num_gqa_groups_per_partition + ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" 
- if backend == 'transformer_engine': + if backend == "transformer_engine": max_s_q = query_layer.shape[1] max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1] self.fused_attention_backend = tex.get_fused_attn_backend( - TE_DType[query_layer.dtype], TE_DType[query_layer.dtype], - tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type], - AttnMaskType[self.attn_mask_type], self.attention_dropout, query_layer.shape[-2], - key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2], max_s_q, - max_s_kv, query_layer.shape[-1]) - - is_backend_avail = (self.fused_attention_backend in [ - FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"] - ]) + TE_DType[query_layer.dtype], + TE_DType[query_layer.dtype], + tex.get_nvte_qkv_layout(self.qkv_layout), + AttnBiasType[core_attention_bias_type], + AttnMaskType[self.attn_mask_type], + self.attention_dropout, + query_layer.shape[-2], + key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2], + max_s_q, + max_s_kv, + query_layer.shape[-1], + ) + + is_backend_avail = self.fused_attention_backend in [ + FusedAttnBackend["F16_max512_seqlen"], + FusedAttnBackend["F16_arbitrary_seqlen"], + ] if is_backend_avail and self.use_fused_attention: - return self._te_forward(query_layer, key_layer, value_layer, attention_mask, - core_attention_bias_type, core_attention_bias, set_zero) + return self._te_forward( + query_layer, + key_layer, + value_layer, + attention_mask, + core_attention_bias_type, + core_attention_bias, + set_zero, + ) warnings.warn("Fused attention is not enabled, falling back to Paddle backend") - backend = 'paddle' - self.scale_mask_softmax = FusedScaleMaskSoftmax(self.attn_mask_type, - attention_mask_func, - backend=backend) - if backend == 'paddle': + backend = "paddle" + self.scale_mask_softmax = FusedScaleMaskSoftmax( + self.attn_mask_type, attention_mask_func, backend=backend + ) + if backend == "paddle": if core_attention_bias_type != "no_bias": - warnings.warn("Paddle backend dot product attention does not support bias yet. " - "Bias will be ignored.") + warnings.warn( + "Paddle backend dot product attention does not support bias yet. " + "Bias will be ignored." 
+ ) return self._pd_forward(query_layer, key_layer, value_layer, attention_mask) raise AttributeError(f"Backend {backend} is not supported.") @@ -425,45 +569,76 @@ class DotProductAttention(paddle.nn.Layer): if self.attention_type == "self": # self attention - q: [b, s, h, d] kv: None - assert (len(query_layer.shape) == 4 and len(key_layer.shape) == 4 - and len(value_layer.shape) - == 4), "q,k,v shape must be [b, s, h, d] for dot product self attention" + assert ( + len(query_layer.shape) == 4 + and len(key_layer.shape) == 4 + and len(value_layer.shape) == 4 + ), "q,k,v shape must be [b, s, h, d] for dot product self attention" max_seqlen = query_layer.shape[1] if self.attn_mask_type == "causal" or attention_mask is None: - cu_seqlens = paddle.arange(0, (query_layer.shape[0] + 1) * query_layer.shape[1], - step=query_layer.shape[1], - dtype='int32') + cu_seqlens = paddle.arange( + 0, + (query_layer.shape[0] + 1) * query_layer.shape[1], + step=query_layer.shape[1], + dtype="int32", + ) else: cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False) qkv_dtype = TE_DType[query_layer.dtype] - output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens, - cu_seqlens, core_attention_bias, max_seqlen, max_seqlen, - 1.0 / self.norm_factor, qkv_dtype, - self.attention_dropout if self.training else 0.0, set_zero, - self.qkv_layout, core_attention_bias_type, - self.attn_mask_type, self.training, - self.fused_attention_backend) + output = FusedAttnFunc.apply( + query_layer, + key_layer, + value_layer, + cu_seqlens, + cu_seqlens, + core_attention_bias, + max_seqlen, + max_seqlen, + 1.0 / self.norm_factor, + qkv_dtype, + self.attention_dropout if self.training else 0.0, + set_zero, + self.qkv_layout, + core_attention_bias_type, + self.attn_mask_type, + self.training, + self.fused_attention_backend, + ) elif self.attention_type == "cross": # cross attention - q: [b, s_q, h, d] k,v: [b, s_kv, h, d] assert ( - len(query_layer.shape) == 4 and len(key_layer.shape) == 4 + len(query_layer.shape) == 4 + and len(key_layer.shape) == 4 and len(value_layer.shape) == 4 - ), "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]" \ + ), ( + "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]" "for dot product cross attention" - assert (attention_mask - is not None), "attention_mask must be provided for cross attention" + ) + assert attention_mask is not None, "attention_mask must be provided for cross attention" max_seqlen_q = query_layer.shape[1] max_seqlen_kv = key_layer.shape[1] cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True) qkv_dtype = TE_DType[query_layer.dtype] - output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens_q, - cu_seqlens_kv, core_attention_bias, max_seqlen_q, - max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype, - self.attention_dropout if self.training else 0.0, set_zero, - self.qkv_layout, core_attention_bias_type, - self.attn_mask_type, self.training, - self.fused_attention_backend) + output = FusedAttnFunc.apply( + query_layer, + key_layer, + value_layer, + cu_seqlens_q, + cu_seqlens_kv, + core_attention_bias, + max_seqlen_q, + max_seqlen_kv, + 1.0 / self.norm_factor, + qkv_dtype, + self.attention_dropout if self.training else 0.0, + set_zero, + self.qkv_layout, + core_attention_bias_type, + self.attn_mask_type, + self.training, + self.fused_attention_backend, + ) else: raise ValueError("attention_type must be one of ['self', 'cross']") return output @@ -495,7 +670,7 @@ class 
DotProductAttention(paddle.nn.Layer): ) out = paddle.matmul(attention_probs, v) - out = paddle.transpose(out, perm=[0, 2, 1, 3]) # [b, s, h, d] + out = paddle.transpose(out, perm=[0, 2, 1, 3]) # [b, s, h, d] # out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) return out @@ -595,8 +770,8 @@ class MultiHeadAttention(paddle.nn.Layer): tp_group: Optional[dist_group_type] = None, num_gqa_groups: Optional[int] = None, fuse_wgrad_accumulation: bool = False, - rng_state_name: str = 'local_seed', - backend: str = 'transformer_engine', + rng_state_name: str = "local_seed", + backend: str = "transformer_engine", ) -> None: super().__init__() self.input_layernorm = input_layernorm @@ -610,8 +785,9 @@ class MultiHeadAttention(paddle.nn.Layer): assert attention_type in AttnTypes, f"attention_type {attention_type} not supported" - self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, - enable_tp=set_parallel_mode) + self.tp_group, self.tp_size = get_tp_group_and_world_size( + tp_group, enable_tp=set_parallel_mode + ) self.tensor_parallel = self.tp_size > 1 self.sequence_parallel = self.tensor_parallel and sequence_parallel self.hidden_size_per_attention_head = hidden_size // num_attention_heads @@ -621,11 +797,13 @@ class MultiHeadAttention(paddle.nn.Layer): self.backend = backend self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size) - self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups) - assert (self.num_attention_heads % self.num_gqa_groups == 0 - ), "The number of attention heads must be divisible by the number of GQA groups!" - assert (self.num_gqa_groups % self.tp_size == 0 - ), "The number of GQA groups must be divisible by tensor parallel size!" + self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups + assert ( + self.num_attention_heads % self.num_gqa_groups == 0 + ), "The number of attention heads must be divisible by the number of GQA groups!" + assert ( + self.num_gqa_groups % self.tp_size == 0 + ), "The number of GQA groups must be divisible by tensor parallel size!" 
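The divisibility asserts above exist because the fused QKV projection output is reshaped to [b, s, h/ng + 2, ng, d] and split into query, key and value, as the following hunk shows. A minimal sketch of that split, with hypothetical names (not part of this patch):

import paddle

def split_grouped_qkv(mixed_qkv, num_heads, num_gqa_groups, head_dim):
    """mixed_qkv: [b, s, hidden + 2 * hidden_kv] -> q [b, s, h, d], k/v [b, s, ng, d]."""
    b, s = mixed_qkv.shape[0], mixed_qkv.shape[1]
    q_per_kv = num_heads // num_gqa_groups
    # Pack (h/ng) query heads plus one key and one value head per GQA group.
    x = mixed_qkv.reshape([b, s, q_per_kv + 2, num_gqa_groups, head_dim])
    q, k, v = paddle.split(x, [q_per_kv, 1, 1], axis=2)
    q = q.reshape([b, s, num_heads, head_dim])
    k = k.reshape([b, s, num_gqa_groups, head_dim])
    v = v.reshape([b, s, num_gqa_groups, head_dim])
    return q, k, v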
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size) self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads) qkv_parallel_mode = "column" if set_parallel_mode else None @@ -660,7 +838,7 @@ class MultiHeadAttention(paddle.nn.Layer): backend=self.backend, ) - else: # cross attention + else: # cross attention if self.input_layernorm: self.layernorm_query = LayerNormLinear( hidden_size, @@ -776,7 +954,7 @@ class MultiHeadAttention(paddle.nn.Layer): """ if self.attn_mask_type != "causal" and attention_mask is not None: - assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor" + assert attention_mask.dtype == paddle.bool, "Attention mask must be a boolean tensor" input_dim = len(hidden_states.shape) if input_dim == 2: @@ -806,15 +984,20 @@ class MultiHeadAttention(paddle.nn.Layer): is_first_microbatch=is_first_microbatch, ) - num_queries_per_key_value = (self.num_attention_heads_per_partition // - self.num_gqa_groups_per_partition) + num_queries_per_key_value = ( + self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition + ) # [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d] - mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[ - -1, max_seq_len, ( - num_queries_per_key_value + - 2), self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head - ]) + mixed_qkv_layer = mixed_qkv_layer.reshape( + shape=[ + -1, + max_seq_len, + (num_queries_per_key_value + 2), + self.num_gqa_groups_per_partition, + self.hidden_size_per_attention_head, + ] + ) # [b, s_q, (h/ng+2), ng, d] # --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d] @@ -826,19 +1009,25 @@ class MultiHeadAttention(paddle.nn.Layer): # query: -> [b, s, h, d] # key, value: -> [b, s, ng, d] - query_layer, key_layer, value_layer = (x.reshape( - shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head]) - for x in (query_layer, key_layer, value_layer)) + query_layer, key_layer, value_layer = ( + x.reshape(shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head]) + for x in (query_layer, key_layer, value_layer) + ) - else: # cross attention + else: # cross attention mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, ) # [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size] - mixed_kv_layer = mixed_kv_layer.reshape(shape=[ - 0, 0, 2 * self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head - ]) + mixed_kv_layer = mixed_kv_layer.reshape( + shape=[ + 0, + 0, + 2 * self.num_gqa_groups_per_partition, + self.hidden_size_per_attention_head, + ] + ) # [b, s_kv, 2 * ng, head_size] # --> 2 [b, s_kv, ng, head_size] @@ -864,16 +1053,21 @@ class MultiHeadAttention(paddle.nn.Layer): ) # [b, s, hidden_size] --> [b, s, h, d] - query_layer = query_layer.reshape(shape=[ - -1, max_seq_len, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head - ]) + query_layer = query_layer.reshape( + shape=[ + -1, + max_seq_len, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ] + ) if rotary_pos_emb is not None: q_pos_emb, k_pos_emb = rotary_pos_emb if fused_rotary_position_embedding is None: - query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, q_pos_emb, - k_pos_emb) + query_layer, key_layer = apply_rotary_pos_emb( + query_layer, key_layer, q_pos_emb, k_pos_emb + ) else: query_layer, key_layer, _ = fused_rotary_position_embedding( query_layer, @@ -911,10 +1105,12 
@@ class MultiHeadAttention(paddle.nn.Layer): if input_dim == 3: context_layer = paddle.reshape( - context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]]) - else: # input_dim == 2 - context_layer = paddle.reshape(context_layer, - [-1, context_layer.shape[2] * context_layer.shape[3]]) + context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]] + ) + else: # input_dim == 2 + context_layer = paddle.reshape( + context_layer, [-1, context_layer.shape[2] * context_layer.shape[3]] + ) # Output. [b, s, hidden] attention_output = self.proj(context_layer, is_first_microbatch=is_first_microbatch) diff --git a/transformer_engine/paddle/layer/base.py b/transformer_engine/paddle/layer/base.py index 6d56322442711dfc42f9fd97a6b19545ddaf85b1..528ca52d101e09bf80b81e5b81a33d0e6700272c 100644 --- a/transformer_engine/paddle/layer/base.py +++ b/transformer_engine/paddle/layer/base.py @@ -12,6 +12,7 @@ from typing import Generator, Dict, Tuple, Union, Any, List, Optional import numpy as np import paddle + try: from paddle.base import core from paddle.base.framework import _dygraph_tracer @@ -52,7 +53,7 @@ def get_workspace() -> paddle.Tensor: if _cublas_workspace is None: _cublas_workspace = paddle.empty( [get_cublas_workspace_size_bytes()], - dtype='uint8', + dtype="uint8", ) return _cublas_workspace @@ -62,7 +63,7 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): def __init__(self) -> None: super().__init__() - assert 'gpu' in paddle.device.get_device(), "TransformerEngine needs CUDA." + assert "gpu" in paddle.device.get_device(), "TransformerEngine needs CUDA." self.fp8_initialized = False self.fp8_enabled = False self.fp8_calibration = False @@ -77,7 +78,8 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): self.sequence_parallel = False self.fp8_meta["autocast_id_fwd_stack"] = [] self.fp8_meta["async_amax_reduction"] = bool( - int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))) + int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")) + ) self.fp8_weight_shapes = [] self.fp8_weight_cache = {} @@ -86,11 +88,11 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): tracer = _dygraph_tracer() if tracer and tracer._amp_level != core.AmpLevel.O0: # Set activation_dtype to the Paddle AMP dtype if under 'paddle.amp.auto_cast' context - if tracer._amp_dtype == 'float32': + if tracer._amp_dtype == "float32": self.activation_dtype = paddle.float32 - elif tracer._amp_dtype == 'bfloat16': + elif tracer._amp_dtype == "bfloat16": self.activation_dtype = paddle.bfloat16 - elif tracer._amp_dtype == 'float16': + elif tracer._amp_dtype == "float16": self.activation_dtype = paddle.float16 else: raise RuntimeError(f"AMP format {tracer._amp_dtype} is not supported.") @@ -110,7 +112,8 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): if param is not None: assert dtype == param.dtype, ( "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}") + f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" + ) self.activation_dtype = dtype @@ -125,8 +128,10 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): if self.fp8_enabled or self.fp8_calibration: # FP8 init has already been run and recipe is the same, don't do anything. 
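get_workspace above is a lazily allocated module-level buffer. A stripped-down sketch of the same pattern, assuming a placeholder 32 MiB size instead of get_cublas_workspace_size_bytes():

# Stripped-down sketch of the lazily-allocated workspace pattern used by
# get_workspace() above. The 32 MiB size is a placeholder, not the value
# returned by get_cublas_workspace_size_bytes().
import paddle

_workspace = None

def get_workspace() -> paddle.Tensor:
    """Allocate the byte buffer on first use and reuse it afterwards."""
    global _workspace
    if _workspace is None:
        _workspace = paddle.empty([32 * 1024 * 1024], dtype="uint8")
    return _workspace

ws1 = get_workspace()
ws2 = get_workspace()
assert ws1 is ws2  # the same buffer is handed out on every call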
- if self.fp8_initialized and global_fp8_state.get_fp8_recipe( - ) == self.fp8_meta["recipe"]: + if ( + self.fp8_initialized + and global_fp8_state.get_fp8_recipe() == self.fp8_meta["recipe"] + ): return # Set FP8, recipe, and other FP8 metadata @@ -156,8 +161,10 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): weight_cast_key = f"weight{i}_fp8" weight_transpose_key = f"weight{i}_t_fp8" - if (weight_cast_key in self.fp8_weight_cache - and self.fp8_weight_cache[weight_cast_key].shape == shape): + if ( + weight_cast_key in self.fp8_weight_cache + and self.fp8_weight_cache[weight_cast_key].shape == shape + ): return self.fp8_weight_cache[weight_cast_key] = paddle.empty( @@ -231,7 +238,8 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): # Load extra items. self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"].amax_history_len = self.fp8_meta["scaling_fwd"].amax_history.shape[ - 0] + 0 + ] recompute_buffer_pos_key = FP8RecomputeBuffer.get_buffer_position_key() if recompute_buffer_pos_key in self.fp8_meta: del self.fp8_meta[recompute_buffer_pos_key] @@ -271,9 +279,10 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): self.set_fp8_weights() if self.fp8_enabled and self.sequence_parallel: - assert self.fp8_meta["recipe"].reduce_amax, \ - "Amax reduction across tensor parallel group is " \ - "necessary when using sequence parallelism with FP8." + assert self.fp8_meta["recipe"].reduce_amax, ( + "Amax reduction across tensor parallel group is " + "necessary when using sequence parallelism with FP8." + ) update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch @@ -283,14 +292,14 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): global_fp8_fwd_buffer.wait() if self.fp8_meta["recipe"].reduce_amax: global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta) - amax_and_scale_update(self.fp8_meta, - True, - update_weight_scale_inv=update_weight_scale_inv) + amax_and_scale_update( + self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv + ) global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta) else: - amax_and_scale_update(self.fp8_meta, - True, - update_weight_scale_inv=update_weight_scale_inv) + amax_and_scale_update( + self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv + ) if self.fp8_enabled and self.training: # Setup for amax reduction @@ -304,8 +313,11 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): self.fp8_meta["update_amax_and_scale_fwd"] = False # Activation recomputation is used and this is the first forward phase. 
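The amax_and_scale_update calls above drive FP8 delayed scaling from an amax history. A simplified, framework-free sketch of what such an update can look like; it is not Transformer Engine's implementation (margin handling, power-of-two rounding, and the fused kernels are omitted):

# Simplified sketch of a delayed-scaling update driven by an amax history, in
# the spirit of the amax_and_scale_update calls above. NOT Transformer Engine's
# implementation; margin handling, rounding and fused kernels are omitted.
FP8_E4M3_MAX = 448.0  # largest representable magnitude in E4M3

def update_scale(amax_history, algo="max"):
    """Pick a representative amax from the history and derive scale / scale_inv."""
    amax = max(amax_history) if algo == "max" else amax_history[0]
    amax = max(amax, 1e-12)            # avoid division by zero
    scale = FP8_E4M3_MAX / amax        # values are multiplied by `scale` before casting to FP8
    return scale, 1.0 / scale          # scale_inv is kept to dequantize results later

history = [0.7, 2.3, 1.1]              # hypothetical per-step amax observations
scale, scale_inv = update_scale(history)
print(round(scale, 2), round(scale_inv, 5))  # -> 194.78 0.00513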
- if (self.fp8_enabled and self.training - and get_global_fp8_state().is_fp8_recompute_enabled()): + if ( + self.fp8_enabled + and self.training + and get_global_fp8_state().is_fp8_recompute_enabled() + ): global_recompute_buffer = get_global_fp8_state().get_fp8_recompute_buffer() global_recompute_buffer.stash_fp8_meta_tensors(self.fp8_meta) @@ -328,11 +340,13 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): @staticmethod @contextmanager - def prepare_backward(fp8_enabled: bool, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - name: str = "") -> Generator[None, None, None]: + def prepare_backward( + fp8_enabled: bool, + fp8_meta: Dict[str, Any], + tp_group: dist_group_type, + tp_size: int, + name: str = "", + ) -> Generator[None, None, None]: """Checks and prep for BWD.""" if fp8_enabled: global_fp8_state = get_global_fp8_state() @@ -358,8 +372,9 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): global_fp8_bwd_buffer.finalize(fp8_meta, tp_group, tp_size) @staticmethod - def grad_output_preprocess(ctx, grad_output: paddle.Tensor, - row_parallel_mode: bool) -> Tuple[Union[paddle.Tensor, None], ...]: + def grad_output_preprocess( + ctx, grad_output: paddle.Tensor, row_parallel_mode: bool + ) -> Tuple[Union[paddle.Tensor, None], ...]: """Utility function for backward. Returns tuple in order (all optional/None based on training precion/recipe): R1: gathered `grad_output` in higher precision. @@ -447,11 +462,14 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): weight_cast_key = f"weight{i}_fp8" weight_transpose_key = f"weight{i}_t_fp8" - assert weight_cast_key in self.fp8_weight_cache, \ - "TE internal error: fp8 weight buffer is not found" + assert ( + weight_cast_key in self.fp8_weight_cache + ), "TE internal error: fp8 weight buffer is not found" - out_list.extend([ - self.fp8_weight_cache[weight_cast_key], - self.fp8_weight_cache[weight_transpose_key], - ]) + out_list.extend( + [ + self.fp8_weight_cache[weight_cast_key], + self.fp8_weight_cache[weight_transpose_key], + ] + ) return out_list diff --git a/transformer_engine/paddle/layer/layernorm.py b/transformer_engine/paddle/layer/layernorm.py index d158a4d1a2e198dbecf49599b9408a8c2cdb424f..208e39ea0332259d1d44049cd818b1504b75121b 100644 --- a/transformer_engine/paddle/layer/layernorm.py +++ b/transformer_engine/paddle/layer/layernorm.py @@ -36,8 +36,15 @@ class _LayerNorm(paddle.autograd.PyLayer): assert inp.shape[-1] == in_features, "LayerNorm not possible" inputmat = inp.reshape((-1, in_features)) - ln_out, mu, rsigma = layernorm_fwd(inputmat, ln_weight, ln_bias, eps, TE_DType[inp.dtype], - fwd_ln_sm_margin, zero_centered_gamma) + ln_out, mu, rsigma = layernorm_fwd( + inputmat, + ln_weight, + ln_bias, + eps, + TE_DType[inp.dtype], + fwd_ln_sm_margin, + zero_centered_gamma, + ) ctx.save_for_backward(inputmat, ln_weight, mu, rsigma) ctx.inp_shape = inp.shape @@ -52,8 +59,9 @@ class _LayerNorm(paddle.autograd.PyLayer): def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: inputmat, ln_weight, mu, rsigma = ctx.saved_tensor() d_ln_out = grad_output.reshape(inputmat.shape) - dxmat, dgamma, dbeta = layernorm_bwd(d_ln_out, inputmat, mu, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma) + dxmat, dgamma, dbeta = layernorm_bwd( + d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + ) return ( dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None, dgamma if ctx.requires_dw else None, @@ -106,7 +114,7 @@ 
class LayerNorm(paddle.nn.Layer): bias_attr: Union[paddle.ParamAttr, None, bool] = None, zero_centered_gamma: bool = False, sequence_parallel: bool = False, - backend: str = 'transformer_engine', + backend: str = "transformer_engine", ) -> None: super().__init__() self.eps = eps @@ -117,8 +125,9 @@ class LayerNorm(paddle.nn.Layer): self._weight_attr = weight_attr if not self._weight_attr: - self._weight_attr = paddle.ParamAttr(initializer=Constant( - value=0.0 if self.zero_centered_gamma else 1.0)) + self._weight_attr = paddle.ParamAttr( + initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0) + ) self._bias_attr = bias_attr if self._bias_attr is False: @@ -151,8 +160,15 @@ class LayerNorm(paddle.nn.Layer): def _te_forward(self, inp: paddle.Tensor) -> paddle.Tensor: """LayerNorm FWD""" - return _LayerNorm.apply(inp, self.weight, self.bias, self.eps, self.fwd_ln_sm_margin, - self.bwd_ln_sm_margin, self.zero_centered_gamma) + return _LayerNorm.apply( + inp, + self.weight, + self.bias, + self.eps, + self.fwd_ln_sm_margin, + self.bwd_ln_sm_margin, + self.zero_centered_gamma, + ) def _pd_forward( self, @@ -161,18 +177,21 @@ class LayerNorm(paddle.nn.Layer): """Calls Paddle OP""" if self.zero_centered_gamma: raise NotImplementedError( - "Paddle backend does not support LayerNorm with zero-centered scale.") - - return F.layer_norm(x=inp, - normalized_shape=inp.shape[-1], - weight=self.weight, - bias=self.bias, - epsilon=self.eps) + "Paddle backend does not support LayerNorm with zero-centered scale." + ) + + return F.layer_norm( + x=inp, + normalized_shape=inp.shape[-1], + weight=self.weight, + bias=self.bias, + epsilon=self.eps, + ) def forward(self, *args, **kwargs): """forward""" - if self.backend == 'transformer_engine': + if self.backend == "transformer_engine": return self._te_forward(*args, **kwargs) - if self.backend == 'paddle': + if self.backend == "paddle": return self._pd_forward(*args, **kwargs) raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/layernorm_linear.py b/transformer_engine/paddle/layer/layernorm_linear.py index 7a7b1e4ff946c736784d7e4ef7f8ae6d5d6785f0..a4b4cba9d481bac7ba5036f59e9c4126f9cb5b01 100644 --- a/transformer_engine/paddle/layer/layernorm_linear.py +++ b/transformer_engine/paddle/layer/layernorm_linear.py @@ -79,14 +79,14 @@ def _apply_normalization_fwd( } fwd_normalization_funcs = { - ('LayerNorm', True, True): layernorm_fwd, - ('LayerNorm', True, False): layernorm_fwd_fp8, - ('LayerNorm', False, True): layernorm_fwd, - ('LayerNorm', False, False): layernorm_fwd, - ('RMSNorm', True, True): rmsnorm_fwd, - ('RMSNorm', True, False): rmsnorm_fwd_fp8, - ('RMSNorm', False, True): rmsnorm_fwd, - ('RMSNorm', False, False): rmsnorm_fwd, + ("LayerNorm", True, True): layernorm_fwd, + ("LayerNorm", True, False): layernorm_fwd_fp8, + ("LayerNorm", False, True): layernorm_fwd, + ("LayerNorm", False, False): layernorm_fwd, + ("RMSNorm", True, True): rmsnorm_fwd, + ("RMSNorm", True, False): rmsnorm_fwd_fp8, + ("RMSNorm", False, True): rmsnorm_fwd, + ("RMSNorm", False, False): rmsnorm_fwd, } if normalization == "LayerNorm": @@ -107,7 +107,7 @@ def _apply_normalization_fwd( if normalization == "LayerNorm": norm_out_return, mu, rsigma = out_tuple - else: # RMSNorm + else: # RMSNorm norm_out_return, rsigma = out_tuple mu = None @@ -165,7 +165,7 @@ def _apply_normalization_bwd( out_tuple = norm_bwd_func(**norm_bwd_kwargs) if normalization == "LayerNorm": dxmat, dgamma, dbeta = out_tuple - else: # RMSNorm + 
else: # RMSNorm dxmat, dgamma = out_tuple dbeta = None @@ -207,7 +207,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: if normalization == "RMSNorm": assert ln_bias is None, "RMSNorm does not support bias!" - else: # LayerNorm + else: # LayerNorm assert ln_bias is not None, "LayerNorm requires bias!" # Make sure input dimensions are compatible in_features = ln_weight.shape[0] @@ -305,14 +305,12 @@ class _LayerNormLinear(paddle.autograd.PyLayer): @staticmethod def backward( - ctx, *grad_outputs: Tuple[paddle.Tensor, - ...]) -> Tuple[Union[paddle.Tensor, None], ...]: - with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled, - ctx.fp8_meta, - ctx.tp_group, - ctx.tp_size, - name="_LayerNormLinear"): - ( # pylint: disable=unbalanced-tuple-unpacking + ctx, *grad_outputs: Tuple[paddle.Tensor, ...] + ) -> Tuple[Union[paddle.Tensor, None], ...]: + with TransformerEngineBaseLayer.prepare_backward( + ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear" + ): + ( # pylint: disable=unbalanced-tuple-unpacking inputmat, ln_weight, mu, @@ -328,12 +326,14 @@ class _LayerNormLinear(paddle.autograd.PyLayer): grad_output_c, grad_output_t, bgrad, - ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0], - ctx.parallel_mode == "row") + ) = TransformerEngineBaseLayer.grad_output_preprocess( + ctx, grad_outputs[0], ctx.parallel_mode == "row" + ) if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation - and not ctx.is_first_microbatch) + accumulate_wgrad_into_param_main_grad = ( + ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch + ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation @@ -353,7 +353,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer): # Linear Bwd dgrad, wgrad, bgrad_ = _linear_bwd( linear_inputmat, - None, # inputmat_t will be automatically computed if not provided + None, # inputmat_t will be automatically computed if not provided FP8FwdTensors.GEMM1_INPUT, weight, weight_t_fp8, @@ -366,7 +366,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ctx.requires_bgrad, ctx.fp8_enabled, ctx.fp8_meta, - True, # Always compute dgrad to feed into LayerNorm bwd + True, # Always compute dgrad to feed into LayerNorm bwd ctx.requires_wgrad, ctx.activation_dtype, ctx.parallel_mode, @@ -479,14 +479,14 @@ class LayerNormLinear(TransformerEngineBaseLayer): eps: float = 1e-5, weight_attr: Union[paddle.ParamAttr, None] = None, bias_attr: Union[paddle.ParamAttr, None, bool] = None, - normalization: str = 'LayerNorm', + normalization: str = "LayerNorm", return_layernorm_output: bool = False, zero_centered_gamma: bool = False, parallel_mode: Optional[str] = None, sequence_parallel: bool = False, tp_group: Union[dist_group_type, None] = None, fuse_wgrad_accumulation: bool = False, - backend: str = 'transformer_engine', + backend: str = "transformer_engine", ) -> None: super().__init__() @@ -494,7 +494,7 @@ class LayerNormLinear(TransformerEngineBaseLayer): self.out_features = out_features self.eps = eps self.normalization = normalization - assert normalization in ['LayerNorm', 'RMSNorm'], "Unsupported normalization type!" + assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" 
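The column branch above shrinks out_features by the tensor-parallel size. A small sketch of the Megatron-style convention this follows; the row-parallel case (splitting the input dimension) is shown as the usual counterpart, not quoted from this file:

# Sketch of Megatron-style weight partitioning for tensor parallelism. The
# column-parallel branch mirrors the divide(out_features, tp_size) logic above;
# the row-parallel branch is the conventional counterpart and is illustrative.
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
    return numerator // denominator

def local_weight_shape(in_features, out_features, parallel_mode, tp_size):
    if parallel_mode == "column":
        # each rank owns a slice of the output features: Y = [Y1 | Y2 | ...]
        return in_features, divide(out_features, tp_size)
    if parallel_mode == "row":
        # each rank owns a slice of the input features; partial outputs are all-reduced
        return divide(in_features, tp_size), out_features
    return in_features, out_features  # no tensor parallelism

print(local_weight_shape(1024, 4096, "column", 4))  # -> (1024, 1024)
print(local_weight_shape(4096, 1024, "row", 4))     # -> (1024, 1024)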
self.return_layernorm_output = return_layernorm_output self.zero_centered_gamma = zero_centered_gamma self.backend = backend @@ -504,13 +504,14 @@ class LayerNormLinear(TransformerEngineBaseLayer): self._dtype = self._helper.get_default_dtype() # Set parallel configs - self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, - enable_tp=parallel_mode - is not None) + self.tp_group, self.tp_size = get_tp_group_and_world_size( + tp_group, enable_tp=parallel_mode is not None + ) self.tensor_parallel = self.tp_size > 1 self.parallel_mode = parallel_mode - assert (self.parallel_mode - in GemmParallelModes), f"parallel_mode {parallel_mode} not supported" + assert ( + self.parallel_mode in GemmParallelModes + ), f"parallel_mode {parallel_mode} not supported" if self.parallel_mode == "column": self.out_features = divide(self.out_features, self.tp_size) @@ -524,8 +525,9 @@ class LayerNormLinear(TransformerEngineBaseLayer): # LayerNorm weights self.ln_weight = self.create_parameter( shape=[self.in_features], - attr=paddle.ParamAttr(initializer=Constant( - value=0.0 if self.zero_centered_gamma else 1.0)), + attr=paddle.ParamAttr( + initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0) + ), dtype=self._dtype, is_bias=False, ) @@ -548,14 +550,18 @@ class LayerNormLinear(TransformerEngineBaseLayer): with track_rng_state(enable=self.tensor_parallel): # TE linear weight is in column major self.weight = self.create_parameter( - shape=[self.out_features, self.in_features] - if self.backend == 'transformer_engine' else [self.in_features, self.out_features], + shape=( + [self.out_features, self.in_features] + if self.backend == "transformer_engine" + else [self.in_features, self.out_features] + ), attr=self._weight_attr, dtype=self._dtype, is_bias=False, ) - set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode, - self.backend) + set_weight_tensor_dist_attr( + self.weight, self.tensor_parallel, self.parallel_mode, self.backend + ) self.fp8_weight_shapes.append(self.weight.shape) # Initialize Linear bias parameter @@ -564,8 +570,11 @@ class LayerNormLinear(TransformerEngineBaseLayer): if self.has_bias: self.bias = self.create_parameter( shape=[self.out_features], - attr=self._bias_attr if not use_default_bias else paddle.ParamAttr( - initializer=Constant(value=0.0)), + attr=( + self._bias_attr + if not use_default_bias + else paddle.ParamAttr(initializer=Constant(value=0.0)) + ), dtype=self._dtype, is_bias=True, ) @@ -656,26 +665,30 @@ class LayerNormLinear(TransformerEngineBaseLayer): """Calls Paddle OP""" if self.zero_centered_gamma: raise NotImplementedError( - "Paddle backend does not support LayerNorm with zero-centered scale.") + "Paddle backend does not support LayerNorm with zero-centered scale." + ) if is_first_microbatch is not None: warnings.warn( - "`is_first_microbatch` is not supported for paddle backend and is ignored.") + "`is_first_microbatch` is not supported for paddle backend and is ignored." 
+ ) if self.normalization == "RMSNorm": norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps) norm_out = inp * norm * self.ln_weight - else: # LayerNorm - norm_out = F.layer_norm(x=inp, - normalized_shape=inp.shape[-1], - weight=self.ln_weight, - bias=self.ln_bias, - epsilon=self.eps) - - if self.parallel_mode == 'column' and self.tensor_parallel: + else: # LayerNorm + norm_out = F.layer_norm( + x=inp, + normalized_shape=inp.shape[-1], + weight=self.ln_weight, + bias=self.ln_bias, + epsilon=self.eps, + ) + + if self.parallel_mode == "column" and self.tensor_parallel: norm_out = identity(norm_out, self.tp_group) out = F.linear(norm_out, self.weight, self.bias if self.gemm_bias_fused_add else None) - if self.parallel_mode == 'row' and self.tensor_parallel: + if self.parallel_mode == "row" and self.tensor_parallel: out, _ = allreduce(out, self.tp_group) out = out + self.bias if self.bias is not None else out if self.return_layernorm_output: @@ -701,8 +714,8 @@ class LayerNormLinear(TransformerEngineBaseLayer): * during FP8 training, it allows caching of the FP8 versions of the weights """ - if self.backend == 'transformer_engine': + if self.backend == "transformer_engine": return self._te_forward(*args, **kwargs) - if self.backend == 'paddle': + if self.backend == "paddle": return self._pd_forward(*args, **kwargs) raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/layernorm_mlp.py b/transformer_engine/paddle/layer/layernorm_mlp.py index 6443f1fd2ec02820cebc2fb1b4107e25fff3d0b3..d91b5283b1430523a82d709692c1a82e8747546c 100644 --- a/transformer_engine/paddle/layer/layernorm_mlp.py +++ b/transformer_engine/paddle/layer/layernorm_mlp.py @@ -56,7 +56,7 @@ def _mlp_forward( fc1_weight_fp8_index: FP8FwdTensors, fc1_bias: Union[paddle.Tensor, None], use_fc1_bias: bool, - fc2_input_fp8_index: FP8FwdTensors, # FP8FwdTensors.GEMM2_INPUT + fc2_input_fp8_index: FP8FwdTensors, # FP8FwdTensors.GEMM2_INPUT fc2_weight: paddle.Tensor, fc2_weight_fp8: Optional[paddle.Tensor], fc2_weight_t_fp8: Optional[paddle.Tensor], @@ -88,7 +88,7 @@ def _mlp_forward( use_fc1_bias, fp8_meta, activation_dtype, - 'column' if set_parallel_mode else None, + "column" if set_parallel_mode else None, tensor_parallel, sequence_parallel, tp_group, @@ -123,7 +123,7 @@ def _mlp_forward( use_fc2_bias, fp8_meta, activation_dtype, - 'row' if set_parallel_mode else None, + "row" if set_parallel_mode else None, tensor_parallel, sequence_parallel, tp_group, @@ -141,7 +141,7 @@ def _mlp_forward( fp8_calibration, fp8_meta, activation_dtype, - 'column' if set_parallel_mode else None, + "column" if set_parallel_mode else None, tensor_parallel, sequence_parallel, tp_group, @@ -166,7 +166,7 @@ def _mlp_forward( fp8_calibration, fp8_meta, activation_dtype, - 'row' if set_parallel_mode else None, + "row" if set_parallel_mode else None, tensor_parallel, sequence_parallel, tp_group, @@ -181,17 +181,17 @@ def _mlp_forward( def _mlp_backward( - fc1_input: paddle.Tensor, # ln_out, BF16 / FP8 + fc1_input: paddle.Tensor, # ln_out, BF16 / FP8 fc1_input_fp8_index: FP8FwdTensors, fc1_weight: paddle.Tensor, fc1_weight_t_fp8: paddle.Tensor, fc1_weight_fp8_index: FP8FwdTensors, - fc1_grad_output_fp8_index: FP8BwdTensors, # FP8BwdTensors.GRAD_OUTPUT2 + fc1_grad_output_fp8_index: FP8BwdTensors, # FP8BwdTensors.GRAD_OUTPUT2 requires_fc1_wgrad: bool, requires_fc1_bgrad: bool, fc1_out: paddle.Tensor, - fc2_input: paddle.Tensor, # gelu_out - fc2_input_fp8_index: FP8FwdTensors, # 
FP8FwdTensors.GEMM2_INPUT + fc2_input: paddle.Tensor, # gelu_out + fc2_input_fp8_index: FP8FwdTensors, # FP8FwdTensors.GEMM2_INPUT fc2_weight: paddle.Tensor, fc2_weight_t_fp8: paddle.Tensor, fc2_weight_fp8_index: FP8FwdTensors, @@ -200,7 +200,7 @@ def _mlp_backward( grad_output: paddle.Tensor, grad_output_c: paddle.Tensor, grad_output_t: paddle.Tensor, - grad_output_fp8_index: FP8BwdTensors, # FP8BwdTensors.GRAD_OUTPUT1 + grad_output_fp8_index: FP8BwdTensors, # FP8BwdTensors.GRAD_OUTPUT1 fwd_scale_inverses: paddle.Tensor, fp8_enabled: bool, fp8_meta: Dict[str, Any], @@ -220,7 +220,13 @@ def _mlp_backward( fc1_bgrad, fc2_wgrad, fc2_bgrad, - ) = None, None, None, None, None + ) = ( + None, + None, + None, + None, + None, + ) if fp8_enabled: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -252,7 +258,7 @@ def _mlp_backward( True, requires_fc2_wgrad, activation_dtype, - 'row' if set_parallel_mode else None, + "row" if set_parallel_mode else None, tensor_parallel, sequence_parallel, tp_group, @@ -316,7 +322,7 @@ def _mlp_backward( requires_dgrad, requires_fc1_wgrad, activation_dtype, - 'column' if set_parallel_mode else None, + "column" if set_parallel_mode else None, tensor_parallel, sequence_parallel, tp_group, @@ -332,7 +338,7 @@ def _mlp_backward( True, requires_fc2_wgrad, activation_dtype, - 'row' if set_parallel_mode else None, + "row" if set_parallel_mode else None, tensor_parallel, sequence_parallel, tp_group, @@ -353,7 +359,7 @@ def _mlp_backward( requires_dgrad, requires_fc1_wgrad, activation_dtype, - 'column' if set_parallel_mode else None, + "column" if set_parallel_mode else None, tensor_parallel, sequence_parallel, tp_group, @@ -410,7 +416,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: if normalization == "RMSNorm": assert ln_bias is None, "RMSNorm does not support bias!" - else: # LayerNorm + else: # LayerNorm assert ln_bias is not None, "LayerNorm requires bias!" # Make sure input dimensions are compatible in_features = ln_weight.shape[0] @@ -532,14 +538,12 @@ class _LayerNormMLP(paddle.autograd.PyLayer): @staticmethod def backward( - ctx, *grad_outputs: Tuple[paddle.Tensor, - ...]) -> Tuple[Union[paddle.Tensor, None], ...]: - with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled, - ctx.fp8_meta, - ctx.tp_group, - ctx.tp_size, - name="_LayerNormMLP"): - ( # pylint: disable=unbalanced-tuple-unpacking + ctx, *grad_outputs: Tuple[paddle.Tensor, ...] 
+ ) -> Tuple[Union[paddle.Tensor, None], ...]: + with TransformerEngineBaseLayer.prepare_backward( + ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP" + ): + ( # pylint: disable=unbalanced-tuple-unpacking inputmat, ln_weight, mu, @@ -554,7 +558,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer): fwd_scale_inverses, ) = saved_tensor_allow_none(ctx) - ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess + ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess ( grad_output, grad_output_c, @@ -563,8 +567,9 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0], True) if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation - and not ctx.is_first_microbatch) + accumulate_wgrad_into_param_main_grad = ( + ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch + ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation @@ -731,7 +736,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): sequence_parallel: bool = False, tp_group: Optional[dist_group_type] = None, fuse_wgrad_accumulation: bool = False, - backend: str = 'transformer_engine', + backend: str = "transformer_engine", ) -> None: super().__init__() @@ -750,8 +755,9 @@ class LayerNormMLP(TransformerEngineBaseLayer): self._dtype = self._helper.get_default_dtype() # Set parallel configs - self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, - enable_tp=set_parallel_mode) + self.tp_group, self.tp_size = get_tp_group_and_world_size( + tp_group, enable_tp=set_parallel_mode + ) self.tensor_parallel = self.tp_size > 1 self.set_parallel_mode = set_parallel_mode self.sequence_parallel = self.tensor_parallel and sequence_parallel @@ -766,8 +772,9 @@ class LayerNormMLP(TransformerEngineBaseLayer): # LayerNorm weights self.ln_weight = self.create_parameter( shape=[self.hidden_size], - attr=paddle.ParamAttr(initializer=Constant( - value=0.0 if self.zero_centered_gamma else 1.0)), + attr=paddle.ParamAttr( + initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0) + ), dtype=self._dtype, is_bias=False, ) @@ -795,16 +802,18 @@ class LayerNormMLP(TransformerEngineBaseLayer): with track_rng_state(enable=self.tensor_parallel): self.fc1_weight = self.create_parameter( - shape=[fc1_output_features, self.hidden_size] if self.backend - == 'transformer_engine' else [self.hidden_size, fc1_output_features], + shape=( + [fc1_output_features, self.hidden_size] + if self.backend == "transformer_engine" + else [self.hidden_size, fc1_output_features] + ), attr=self._weight_attr, dtype=self._dtype, is_bias=False, ) - set_weight_tensor_dist_attr(self.fc1_weight, - self.tensor_parallel, - parallel_mode='column', - backend=self.backend) + set_weight_tensor_dist_attr( + self.fc1_weight, self.tensor_parallel, parallel_mode="column", backend=self.backend + ) self.fp8_weight_shapes.append(self.fc1_weight.shape) self.has_bias = self._bias_attr is not False @@ -825,16 +834,18 @@ class LayerNormMLP(TransformerEngineBaseLayer): # FC2 weights self.fc2_weight = self.create_parameter( - shape=[self.hidden_size, self.size_per_partition] if self.backend - == 'transformer_engine' else [self.size_per_partition, self.hidden_size], + shape=( + [self.hidden_size, self.size_per_partition] + if self.backend == "transformer_engine" + else [self.size_per_partition, self.hidden_size] + ), attr=self._weight_attr, dtype=self._dtype, is_bias=False, ) - 
set_weight_tensor_dist_attr(self.fc2_weight, - self.tensor_parallel, - parallel_mode='row', - backend=self.backend) + set_weight_tensor_dist_attr( + self.fc2_weight, self.tensor_parallel, parallel_mode="row", backend=self.backend + ) self.fp8_weight_shapes.append(self.fc2_weight.shape) if self.has_bias: @@ -880,8 +891,9 @@ class LayerNormMLP(TransformerEngineBaseLayer): inp = cast_if_needed(inp, self.activation_dtype) # Get persistent fp8 weight buffer. None if buffer does not exist. - fc1_weight_fp8, fc1_weight_t_fp8, fc2_weight_fp8, fc2_weight_t_fp8 = \ - self.get_fp8_weights_scratchpad(is_first_microbatch) + fc1_weight_fp8, fc1_weight_t_fp8, fc2_weight_fp8, fc2_weight_t_fp8 = ( + self.get_fp8_weights_scratchpad(is_first_microbatch) + ) out = _LayerNormMLP.apply( inp, @@ -936,28 +948,33 @@ class LayerNormMLP(TransformerEngineBaseLayer): """Calls Paddle OP""" if self.zero_centered_gamma: raise NotImplementedError( - "Paddle backend does not support LayerNorm with zero-centered scale.") + "Paddle backend does not support LayerNorm with zero-centered scale." + ) if is_first_microbatch is not None: warnings.warn( - "`is_first_microbatch` is not supported for paddle backend and is ignored.") + "`is_first_microbatch` is not supported for paddle backend and is ignored." + ) if self.normalization == "RMSNorm": norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps) norm_out = inp * norm * self.ln_weight - else: # LayerNorm - norm_out = F.layer_norm(x=inp, - normalized_shape=inp.shape[-1], - weight=self.ln_weight, - bias=self.ln_bias, - epsilon=self.eps) + else: # LayerNorm + norm_out = F.layer_norm( + x=inp, + normalized_shape=inp.shape[-1], + weight=self.ln_weight, + bias=self.ln_bias, + epsilon=self.eps, + ) if self.set_parallel_mode and self.tensor_parallel: norm_out = identity(norm_out, self.tp_group) fc1_out = F.linear(norm_out, self.fc1_weight, self.fc1_bias) act_func = get_paddle_act_func(self.activation) act_out = act_func(fc1_out) - out = F.linear(act_out, self.fc2_weight, - self.fc2_bias if self.gemm_bias_fused_add else None) + out = F.linear( + act_out, self.fc2_weight, self.fc2_bias if self.gemm_bias_fused_add else None + ) if self.set_parallel_mode and self.tensor_parallel: out, _ = allreduce(out, self.tp_group) out = out + self.fc2_bias if self.fc2_bias is not None else out @@ -984,8 +1001,8 @@ class LayerNormMLP(TransformerEngineBaseLayer): * during FP8 training, it allows caching of the FP8 versions of the weights """ - if self.backend == 'transformer_engine': + if self.backend == "transformer_engine": return self._te_forward(*args, **kwargs) - if self.backend == 'paddle': + if self.backend == "paddle": return self._pd_forward(*args, **kwargs) raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/linear.py b/transformer_engine/paddle/layer/linear.py index 74d67a4545ba9e2375818cc67d2dd092c18d8484..d471d0363f5920a2a68a672d2ddf764359b72dba 100644 --- a/transformer_engine/paddle/layer/linear.py +++ b/transformer_engine/paddle/layer/linear.py @@ -152,22 +152,26 @@ def _linear_fwd_non_fp8( if fp8_calibration: # amax of input - fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = \ - paddle.max(paddle.abs(inputmat_total)).item() + fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = paddle.max( + paddle.abs(inputmat_total) + ).item() # amax of weight - fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = \ - paddle.max(paddle.abs(weight)).item() + 
fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = paddle.max( + paddle.abs(weight) + ).item() fp8_meta["update_amax_and_scale_fwd"] = True - outputs = gemm(weight, - inputmat_total, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - gelu=(activation == 'gelu')) + outputs = gemm( + weight, + inputmat_total, + activation_dtype, + get_workspace(), + bias=bias, + use_bias=use_bias, + gelu=(activation == "gelu"), + ) - if activation == 'gelu': + if activation == "gelu": gelu_out, _, out = outputs return out, gelu_out @@ -382,7 +386,7 @@ def _linear_bwd_non_fp8( activation_dtype, get_workspace(), layout="NN", - gelu=(activation == 'gelu'), + gelu=(activation == "gelu"), gelu_input=gelu_input, grad=True, ) @@ -527,8 +531,11 @@ class _Linear(paddle.autograd.PyLayer): inputmat_t = None if fp8_enabled: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if (not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled - and not sequence_parallel): + if ( + not fp8_meta["recipe"].override_linear_precision.wgrad + and is_grad_enabled + and not sequence_parallel + ): inputmat, inputmat_t = cast_transpose( inputmat, fp8_meta["scaling_fwd"], @@ -599,13 +606,11 @@ class _Linear(paddle.autograd.PyLayer): @staticmethod def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled, - ctx.fp8_meta, - ctx.tp_group, - ctx.tp_size, - name="_Linear"): + with TransformerEngineBaseLayer.prepare_backward( + ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear" + ): - ( # pylint: disable=unbalanced-tuple-unpacking + ( # pylint: disable=unbalanced-tuple-unpacking inputmat, inputmat_t, weight, @@ -618,11 +623,13 @@ class _Linear(paddle.autograd.PyLayer): grad_output_c, grad_output_t, bgrad, - ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output, - ctx.parallel_mode == "row") + ) = TransformerEngineBaseLayer.grad_output_preprocess( + ctx, grad_output, ctx.parallel_mode == "row" + ) if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation - and not ctx.is_first_microbatch) + accumulate_wgrad_into_param_main_grad = ( + ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch + ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation @@ -730,7 +737,7 @@ class Linear(TransformerEngineBaseLayer): sequence_parallel: bool = False, tp_group: Union[dist_group_type, None] = None, fuse_wgrad_accumulation: bool = False, - backend: str = 'transformer_engine', + backend: str = "transformer_engine", ) -> None: super().__init__() self.in_features = in_features @@ -741,13 +748,14 @@ class Linear(TransformerEngineBaseLayer): self._dtype = self._helper.get_default_dtype() # Set parallel configs - self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, - enable_tp=parallel_mode - is not None) + self.tp_group, self.tp_size = get_tp_group_and_world_size( + tp_group, enable_tp=parallel_mode is not None + ) self.tensor_parallel = self.tp_size > 1 self.parallel_mode = parallel_mode - assert (self.parallel_mode - in GemmParallelModes), f"parallel_mode {parallel_mode} not supported" + assert ( + self.parallel_mode in GemmParallelModes + ), f"parallel_mode {parallel_mode} not supported" if self.parallel_mode == "column": self.out_features = divide(self.out_features, self.tp_size) @@ -762,14 +770,18 @@ class Linear(TransformerEngineBaseLayer): with 
track_rng_state(enable=self.tensor_parallel): # TE linear weight is in column major self.weight = self.create_parameter( - shape=[self.out_features, self.in_features] - if self.backend == 'transformer_engine' else [self.in_features, self.out_features], + shape=( + [self.out_features, self.in_features] + if self.backend == "transformer_engine" + else [self.in_features, self.out_features] + ), attr=self._weight_attr, dtype=self._dtype, is_bias=False, ) - set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode, - self.backend) + set_weight_tensor_dist_attr( + self.weight, self.tensor_parallel, self.parallel_mode, self.backend + ) # Initialize bias parameter self.has_bias = self._bias_attr is not False @@ -777,8 +789,11 @@ class Linear(TransformerEngineBaseLayer): if self.has_bias: self.bias = self.create_parameter( shape=[self.out_features], - attr=self._bias_attr if not use_default_bias else paddle.ParamAttr( - initializer=Constant(value=0.0)), + attr=( + self._bias_attr + if not use_default_bias + else paddle.ParamAttr(initializer=Constant(value=0.0)) + ), dtype=self._dtype, is_bias=True, ) @@ -849,11 +864,12 @@ class Linear(TransformerEngineBaseLayer): """Calls Paddle OP""" if is_first_microbatch is not None: warnings.warn( - "`is_first_microbatch` is not supported for paddle backend and is ignored.") - if self.parallel_mode == 'column' and self.tensor_parallel: + "`is_first_microbatch` is not supported for paddle backend and is ignored." + ) + if self.parallel_mode == "column" and self.tensor_parallel: inp = identity(inp, self.tp_group) out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None) - if self.parallel_mode == 'row' and self.tensor_parallel: + if self.parallel_mode == "row" and self.tensor_parallel: out, _ = allreduce(out, self.tp_group) out = out + self.bias if self.bias is not None else out return out @@ -877,8 +893,8 @@ class Linear(TransformerEngineBaseLayer): * during FP8 training, it allows caching of the FP8 versions of the weights """ - if self.backend == 'transformer_engine': + if self.backend == "transformer_engine": return self._te_forward(*args, **kwargs) - if self.backend == 'paddle': + if self.backend == "paddle": return self._pd_forward(*args, **kwargs) raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/rmsnorm.py b/transformer_engine/paddle/layer/rmsnorm.py index b88e08ff5b7911b53a8ad933a93cddb088898028..1afc3d975979b8182dba470808bf03a22ed35a45 100644 --- a/transformer_engine/paddle/layer/rmsnorm.py +++ b/transformer_engine/paddle/layer/rmsnorm.py @@ -33,8 +33,14 @@ class _RMSNorm(paddle.autograd.PyLayer): assert inp.shape[-1] == in_features, "RMSNorm not possible" inputmat = inp.reshape((-1, in_features)) - rmsnorm_out, rsigma = rmsnorm_fwd(inputmat, rmsnorm_weight, eps, TE_DType[inp.dtype], - fwd_rmsnorm_sm_margin, zero_centered_gamma) + rmsnorm_out, rsigma = rmsnorm_fwd( + inputmat, + rmsnorm_weight, + eps, + TE_DType[inp.dtype], + fwd_rmsnorm_sm_margin, + zero_centered_gamma, + ) ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma) ctx.inp_shape = inp.shape @@ -49,8 +55,14 @@ class _RMSNorm(paddle.autograd.PyLayer): def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: inputmat, rmsnorm_weight, rsigma = ctx.saved_tensor() d_rmsnorm_out = grad_output.reshape(inputmat.shape) - dxmat, dgamma = rmsnorm_bwd(d_rmsnorm_out, inputmat, rsigma, rmsnorm_weight, - ctx.bwd_rmsnorm_sm_margin, ctx.zero_centered_gamma) + dxmat, 
dgamma = rmsnorm_bwd( + d_rmsnorm_out, + inputmat, + rsigma, + rmsnorm_weight, + ctx.bwd_rmsnorm_sm_margin, + ctx.zero_centered_gamma, + ) return ( dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None, dgamma if ctx.requires_dw else None, @@ -149,7 +161,8 @@ class RMSNorm(paddle.nn.Layer): ) -> paddle.Tensor: if self.zero_centered_gamma: raise NotImplementedError( - "Paddle backend does not support RMSNorm with zero_centered_gamma.") + "Paddle backend does not support RMSNorm with zero_centered_gamma." + ) norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps) y = inp * norm * self.weight return y diff --git a/transformer_engine/paddle/layer/softmax.py b/transformer_engine/paddle/layer/softmax.py index b195f0305ffa3fa58758b584b5fda2e47b016b43..11549364febde555597390e12fa972e3c725f10f 100644 --- a/transformer_engine/paddle/layer/softmax.py +++ b/transformer_engine/paddle/layer/softmax.py @@ -32,8 +32,9 @@ _default_causal_mask = {} def _get_default_causal_mask(seqlen: int) -> paddle.Tensor: """Return the causal upper triangular mask for softmax input""" if seqlen not in _default_causal_mask: - _default_causal_mask[seqlen] = paddle.triu(paddle.ones((seqlen, seqlen)), - diagonal=1).cast('bool') + _default_causal_mask[seqlen] = paddle.triu(paddle.ones((seqlen, seqlen)), diagonal=1).cast( + "bool" + ) return _default_causal_mask[seqlen] @@ -58,8 +59,9 @@ class ScaledUpperTriangMaskedSoftmax(paddle.autograd.PyLayer): def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: """ScaledUpperTriangMaskedSoftmax bwd""" softmax_results, scale_t = ctx.saved_tensor() - input_grads = scaled_upper_triang_masked_softmax_backward(output_grads, softmax_results, - scale_t[0]) + input_grads = scaled_upper_triang_masked_softmax_backward( + output_grads, softmax_results, scale_t[0] + ) return input_grads, None @@ -140,7 +142,7 @@ class FusedScaleMaskSoftmax(paddle.nn.Layer): attn_mask_type: str, mask_func: Callable, softmax_in_fp32: bool = True, - backend: str = 'transformer_engine', + backend: str = "transformer_engine", ) -> None: super().__init__() self.attn_mask_type = attn_mask_type @@ -162,16 +164,17 @@ class FusedScaleMaskSoftmax(paddle.nn.Layer): self.input_is_bf16 = inp.dtype == paddle.bfloat16 self.input_in_16bit_float = self.input_is_fp16 or self.input_is_bf16 - assert (scale is None or self.softmax_in_fp32), "softmax should be in fp32 when scaled" + assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled" - if self.backend == 'transformer_engine' and not self.is_kernel_available(*inp.shape): + if self.backend == "transformer_engine" and not self.is_kernel_available(*inp.shape): warnings.warn( - "fused kernel is not available for this input shape, fall back to paddle backend") - self.backend = 'paddle' + "fused kernel is not available for this input shape, fall back to paddle backend" + ) + self.backend = "paddle" - if self.backend == 'transformer_engine': + if self.backend == "transformer_engine": return self._te_forward(inp, mask, scale) - if self.backend == 'paddle': + if self.backend == "paddle": return self._pd_forward(inp, mask, scale) raise AttributeError(f"Backend {self.backend} is not supported.") @@ -179,12 +182,13 @@ class FusedScaleMaskSoftmax(paddle.nn.Layer): """Check FusedScaleMaskSoftmax kernel availability based on size""" attn_batches = b * h - if (self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_16bit_float # input must be fp16 - and 16 < s_kv <= 4096 # s_kv must be 16 ~ 
2048 - and s_q % 4 == 0 # s_q must be a multiple of 4 - and attn_batches % 4 == 0 # b * h must be a multiple of 4 - ): + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_16bit_float # input must be fp16 + and 16 < s_kv <= 4096 # s_kv must be 16 ~ 2048 + and s_q % 4 == 0 # s_q must be a multiple of 4 + and attn_batches % 4 == 0 # b * h must be a multiple of 4 + ): if 0 <= s_kv <= 4096: batch_per_block = self.get_batch_per_block(int(s_kv)) @@ -196,10 +200,9 @@ class FusedScaleMaskSoftmax(paddle.nn.Layer): return True return False - def _te_forward(self, - inp: paddle.Tensor, - mask: paddle.Tensor, - scale: Optional[float] = None) -> paddle.Tensor: + def _te_forward( + self, inp: paddle.Tensor, mask: paddle.Tensor, scale: Optional[float] = None + ) -> paddle.Tensor: """Fused masked softmax kernel""" b, h, s_q, s_kv = inp.size() scale = 1.0 if scale is None else scale @@ -216,13 +219,12 @@ class FusedScaleMaskSoftmax(paddle.nn.Layer): return ScaledMaskedSoftmax.apply(inp, mask, scale) return ScaledSoftmax.apply(inp, scale) - def _pd_forward(self, - inp: paddle.Tensor, - mask: paddle.Tensor, - scale: Optional[float] = None) -> paddle.Tensor: + def _pd_forward( + self, inp: paddle.Tensor, mask: paddle.Tensor, scale: Optional[float] = None + ) -> paddle.Tensor: """Call Paddle OP""" if self.input_in_16bit_float and self.softmax_in_fp32: - inp = paddle.cast(inp, 'float32') + inp = paddle.cast(inp, "float32") if scale is not None: inp = inp * scale @@ -235,9 +237,9 @@ class FusedScaleMaskSoftmax(paddle.nn.Layer): if self.input_in_16bit_float and self.softmax_in_fp32: if self.input_is_fp16: - probs = paddle.cast(probs, 'float16') + probs = paddle.cast(probs, "float16") else: - probs = paddle.cast(probs, 'bfloat16') + probs = paddle.cast(probs, "bfloat16") return probs diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py index 61ba17e234d4b675ca6dcf949595cdeb8a398529..c2835a3160316da102ad5787d7a4a0f8dfc8a2e5 100644 --- a/transformer_engine/paddle/layer/transformer.py +++ b/transformer_engine/paddle/layer/transformer.py @@ -112,32 +112,34 @@ class TransformerLayer(paddle.nn.Layer): """ - def __init__(self, - hidden_size: int, - ffn_hidden_size: int, - num_attention_heads: int, - num_gqa_groups: Optional[int] = None, - layernorm_epsilon: float = 1e-5, - hidden_dropout: float = 0.1, - attention_dropout: float = 0.1, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - max_sequence_length: Optional[int] = None, - self_attn_mask_type: str = "causal", - params_dtype: Optional[paddle.dtype] = None, - apply_residual_connection_post_layernorm: bool = False, - output_layernorm: bool = False, - layer_type: str = "encoder", - normalization: str = "LayerNorm", - zero_centered_gamma: bool = False, - activation: str = 'gelu', - set_parallel_mode: bool = False, - sequence_parallel: bool = False, - tp_group: Optional[dist_group_type] = None, - fuse_wgrad_accumulation: bool = False, - attention_dropout_rng_state_name: str = 'local_seed', - hidden_dropout_rng_state_name: str = 'global_seed', - backend: str = 'transformer_engine') -> None: + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + num_attention_heads: int, + num_gqa_groups: Optional[int] = None, + layernorm_epsilon: float = 1e-5, + hidden_dropout: float = 0.1, + attention_dropout: float = 0.1, + weight_attr: Union[paddle.ParamAttr, None] = None, + bias_attr: Union[paddle.ParamAttr, None, bool] = 
None, + max_sequence_length: Optional[int] = None, + self_attn_mask_type: str = "causal", + params_dtype: Optional[paddle.dtype] = None, + apply_residual_connection_post_layernorm: bool = False, + output_layernorm: bool = False, + layer_type: str = "encoder", + normalization: str = "LayerNorm", + zero_centered_gamma: bool = False, + activation: str = "gelu", + set_parallel_mode: bool = False, + sequence_parallel: bool = False, + tp_group: Optional[dist_group_type] = None, + fuse_wgrad_accumulation: bool = False, + attention_dropout_rng_state_name: str = "local_seed", + hidden_dropout_rng_state_name: str = "global_seed", + backend: str = "transformer_engine", + ) -> None: super().__init__() params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype @@ -146,19 +148,23 @@ class TransformerLayer(paddle.nn.Layer): self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm self.self_attn_mask_type = self_attn_mask_type self.set_parallel_mode = set_parallel_mode - self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, - enable_tp=set_parallel_mode) + self.tp_group, self.tp_size = get_tp_group_and_world_size( + tp_group, enable_tp=set_parallel_mode + ) self.tensor_parallel = self.tp_size > 1 self.sequence_parallel = self.tensor_parallel and sequence_parallel self.hidden_dropout_rng_state_name = hidden_dropout_rng_state_name # SP needs local seed for hidden dropout - if self.sequence_parallel and self.hidden_dropout_rng_state_name == 'global_seed': - warnings.warn("RNG state for hidden dropout needs to be different across TP ranks. " - "Forcing hidden_dropout_rng_state_name to 'local_seed'") - self.hidden_dropout_rng_state_name = 'local_seed' + if self.sequence_parallel and self.hidden_dropout_rng_state_name == "global_seed": + warnings.warn( + "RNG state for hidden dropout needs to be different across TP ranks. " + "Forcing hidden_dropout_rng_state_name to 'local_seed'" + ) + self.hidden_dropout_rng_state_name = "local_seed" - assert (self_attn_mask_type - in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported" + assert ( + self_attn_mask_type in AttnMaskTypes + ), f"self_attn_mask_type {self_attn_mask_type} not supported" assert layer_type in LayerTypes, f"layer_type {layer_type} not supported" attention_args = ( @@ -176,7 +182,7 @@ class TransformerLayer(paddle.nn.Layer): "zero_centered_gamma": zero_centered_gamma, "set_parallel_mode": set_parallel_mode, "sequence_parallel": self.sequence_parallel, - 'max_sequence_length': max_sequence_length, + "max_sequence_length": max_sequence_length, "tp_group": tp_group, "num_gqa_groups": num_gqa_groups, "fuse_wgrad_accumulation": fuse_wgrad_accumulation, @@ -295,10 +301,12 @@ class TransformerLayer(paddle.nn.Layer): """ if self.self_attn_mask_type != "causal" and attention_mask is not None: - assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor" + assert attention_mask.dtype == paddle.bool, "Attention mask must be a boolean tensor" - assert core_attention_bias_type in ['no_bias'], f"Only no_bias is supported currently, " \ + assert core_attention_bias_type in ["no_bias"], ( + "Only no_bias is supported currently, " f"but receive core_attention_bias_type = {core_attention_bias_type}" + ) # Self attention. 
self_attention_outputs = self.self_attention( @@ -340,8 +348,9 @@ class TransformerLayer(paddle.nn.Layer): attention_output = inter_attention_outputs residual = bda_output - with track_rng_state(enable=self.tensor_parallel, - name=self.hidden_dropout_rng_state_name): + with track_rng_state( + enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name + ): bda_output = self.fused_dropout_add2(attention_output, residual) # MLP. diff --git a/transformer_engine/paddle/recompute.py b/transformer_engine/paddle/recompute.py index 352d8a2df8d0aac18ac11687c55ac9fe24ca564e..1d64ad0de0de11eef3aea5279fbed988b8ce4b36 100644 --- a/transformer_engine/paddle/recompute.py +++ b/transformer_engine/paddle/recompute.py @@ -12,7 +12,7 @@ from .constants import RecomputeFunctionNames from .fp8 import get_global_fp8_state -__all__ = ['recompute'] +__all__ = ["recompute"] _DISABLE_RECOMPUTE = int(os.getenv("NVTE_DISABLE_RECOMPUTE", "0")) @@ -48,8 +48,9 @@ def recompute(function, *args, **kwargs): kwargs : dict dictionary of string keys for keyword arguments to :attr:`function`. """ - assert not _DISABLE_RECOMPUTE, "Recompute is disabled. " \ - f"Got NVTE_DISABLE_RECOMPUTE={_DISABLE_RECOMPUTE}." + assert ( + not _DISABLE_RECOMPUTE + ), f"Recompute is disabled. Got NVTE_DISABLE_RECOMPUTE={_DISABLE_RECOMPUTE}." global_fp8_state = get_global_fp8_state() diff --git a/transformer_engine/paddle/setup.py b/transformer_engine/paddle/setup.py index 97819a0bb8b3b32cd45705de3c63c34e099f0104..8d30d88b7d98313adb42f712d6ae6dbea14a00dc 100644 --- a/transformer_engine/paddle/setup.py +++ b/transformer_engine/paddle/setup.py @@ -27,7 +27,10 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_ from build_tools.build_ext import get_build_ext # pylint: disable=wrong-import-position -from build_tools.utils import package_files, copy_common_headers # pylint: disable=wrong-import-position +from build_tools.utils import ( + package_files, + copy_common_headers, +) # pylint: disable=wrong-import-position from build_tools.te_version import te_version # pylint: disable=wrong-import-position from build_tools.paddle import setup_paddle_extension # pylint: disable=wrong-import-position @@ -38,12 +41,12 @@ CMakeBuildExtension = get_build_ext(BuildExtension) if __name__ == "__main__": # Extensions common_headers_dir = "common_headers" - copy_common_headers( - current_file_path.parent, - str(current_file_path / common_headers_dir)) + copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir)) ext_modules = [ setup_paddle_extension( - "csrc", current_file_path / "csrc", current_file_path / common_headers_dir)] + "csrc", current_file_path / "csrc", current_file_path / common_headers_dir + ) + ] # Configure package setuptools.setup( @@ -56,9 +59,11 @@ if __name__ == "__main__": install_requires=["paddlepaddle-gpu"], tests_require=["numpy"], include_package_data=True, - package_data={"csrc": package_files("csrc"), - common_headers_dir: package_files(common_headers_dir), - "build_tools": package_files("build_tools")}, + package_data={ + "csrc": package_files("csrc"), + common_headers_dir: package_files(common_headers_dir), + "build_tools": package_files("build_tools"), + }, ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) diff --git a/transformer_engine/paddle/utils.py b/transformer_engine/paddle/utils.py index 3187b55222568ff694991161d7c880d6a66f13be..7b9aabbf5a5273a88239f98c42be03d829b8b8eb 100644 --- 
a/transformer_engine/paddle/utils.py +++ b/transformer_engine/paddle/utils.py @@ -10,33 +10,36 @@ import paddle.nn.functional as F from .cpp_extensions import swiglu_pd -def cast_if_needed(tensor: Union[paddle.Tensor, None], - dtype: paddle.dtype) -> Union[paddle.Tensor, None]: +def cast_if_needed( + tensor: Union[paddle.Tensor, None], dtype: paddle.dtype +) -> Union[paddle.Tensor, None]: """Cast tensor to dtype""" return tensor if tensor is None or tensor.dtype == dtype else paddle.cast(tensor, dtype) -def cast_if_needed_inplace(tensor: Union[paddle.Tensor, None], - dtype: paddle.dtype) -> Union[paddle.Tensor, None]: +def cast_if_needed_inplace( + tensor: Union[paddle.Tensor, None], dtype: paddle.dtype +) -> Union[paddle.Tensor, None]: """Cast tensor to dtype (inplace), not to be used on layer inputs""" return tensor if tensor is None or tensor.dtype == dtype else tensor._to(dtype=dtype) def check_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> bool: """For fp8 fprop (TN layout), inputs and weights must be such - that dim0 is divisible by 8 and dim1 is divisible by 16. + that dim0 is divisible by 8 and dim1 is divisible by 16. """ return not tensor.shape[0] % 8 and not tensor.shape[1] % 16 def assert_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> None: """For fp8 fprop (TN layout), inputs and weights must be such - that dim0 is divisible by 8 and dim1 is divisible by 16. + that dim0 is divisible by 8 and dim1 is divisible by 16. """ # single tensor check so it's clear which tensor is triggering the assertion assert check_dim_for_fp8_forward_exec(tensor), ( "Tensor dimensions are not compatible for FP8 execution: " - f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)") + f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)" + ) def get_bias_dtype(activation_dtype: paddle.dtype): @@ -47,18 +50,19 @@ def get_bias_dtype(activation_dtype: paddle.dtype): def get_paddle_act_func(activation): """Get paddle activation function""" funcs = { - 'gelu': F.gelu, - 'relu': F.relu, - 'silu': F.silu, - 'swiglu': swiglu_pd, + "gelu": F.gelu, + "relu": F.relu, + "silu": F.silu, + "swiglu": swiglu_pd, } if activation not in funcs: raise "Activation type " + activation + " is not supported." 
return funcs[activation] -def attention_mask_func(attention_scores: paddle.Tensor, - attention_mask: paddle.Tensor) -> paddle.Tensor: +def attention_mask_func( + attention_scores: paddle.Tensor, attention_mask: paddle.Tensor +) -> paddle.Tensor: """Get attention mask""" def _masked_fill(x, mask, value): @@ -71,14 +75,14 @@ def attention_mask_func(attention_scores: paddle.Tensor, def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Tensor: """Convert mask to cu_seqlens""" - assert 'bool' in str(mask.dtype), "mask must be bool dtype" + assert "bool" in str(mask.dtype), "mask must be bool dtype" assert len(mask.shape) == 4 and mask.shape[1] == 1, "mask must be [b, 1, s_q, s_kv]" - q_actual_seqlens = paddle.sum(mask[:, :, :, 0].logical_not(), axis=(-1, -2), dtype='int32') + q_actual_seqlens = paddle.sum(mask[:, :, :, 0].logical_not(), axis=(-1, -2), dtype="int32") q_cu_seqlens = paddle.cumsum(q_actual_seqlens) q_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), q_cu_seqlens], axis=0) if not need_kv: return q_cu_seqlens, None - kv_actual_seqlens = paddle.sum(mask[:, :, 0, :].logical_not(), axis=(-1, -2), dtype='int32') + kv_actual_seqlens = paddle.sum(mask[:, :, 0, :].logical_not(), axis=(-1, -2), dtype="int32") kv_cu_seqlens = paddle.cumsum(kv_actual_seqlens) kv_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), kv_cu_seqlens], axis=0) return q_cu_seqlens, kv_cu_seqlens @@ -87,7 +91,7 @@ def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Ten def divide(numerator: int, denominator: int) -> int: """Ensure that numerator is divisible by the denominator and return the division value.""" - assert (numerator % denominator == 0), f"{numerator} is not divisible by {denominator}" + assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}" return numerator // denominator @@ -110,8 +114,9 @@ def save_for_backward_allow_none(ctx, *args) -> None: def saved_tensor_allow_none(ctx) -> Tuple[Optional[paddle.Tensor]]: """Used with `save_for_backward_allow_none` in pair. Get saved tensors from ctx.""" - assert hasattr(ctx, '_indices_mapping'), "`saved_tensor_allow_none` must be used " \ - "with `save_for_backward_allow_none` in pair." + assert hasattr( + ctx, "_indices_mapping" + ), "`saved_tensor_allow_none` must be used with `save_for_backward_allow_none` in pair." 
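mask_to_cu_seqlens above turns a [b, 1, s_q, s_kv] boolean padding mask into cumulative sequence lengths. A worked NumPy stand-in showing only the arithmetic, with a hypothetical two-sequence batch:

# Worked example of the cu_seqlens convention produced by mask_to_cu_seqlens
# above: True means padded, and valid query positions are counted per sequence.
# Plain NumPy stand-in for the Paddle helper, used only to show the arithmetic.
import numpy as np

b, s = 2, 4
valid_lens = [3, 2]                               # sequence 0 has 3 valid tokens, sequence 1 has 2
mask = np.ones((b, 1, s, s), dtype=bool)          # start fully padded
for i, n in enumerate(valid_lens):
    mask[i, 0, :n, :n] = False                    # unmask the valid top-left block

q_actual_seqlens = (~mask[:, :, :, 0]).sum(axis=(-1, -2))   # -> [3, 2]
q_cu_seqlens = np.concatenate([[0], np.cumsum(q_actual_seqlens)])
print(q_cu_seqlens)  # -> [0 3 5]: token offsets of each sequence in the packed layout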
indices_mapping = ctx._indices_mapping outputs = [] @@ -132,8 +137,12 @@ def clear_tensor_data(*tensors: Tuple[Optional[paddle.Tensor], ...]) -> None: """ def can_free(t): - return (t is not None and isinstance(t, paddle.Tensor) and t._is_initialized() - and t.inplace_version == 0) + return ( + t is not None + and isinstance(t, paddle.Tensor) + and t._is_initialized() + and t.inplace_version == 0 + ) for t in tensors: if can_free(t): diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 126c22c656bc8da13cb542f406c327cfd75960e5..2c2d1ed1a0cb3fa48eb598cd486cafa09543cd9f 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -84,12 +84,12 @@ if _flash_attn_version >= _flash_attn_version_required: from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd -META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT +META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 -META_O = tex.FP8FwdTensors.GEMM2_INPUT -META_DO = tex.FP8BwdTensors.GRAD_INPUT2 -META_S = tex.FP8FwdTensors.GEMM3_OUTPUT -META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +META_O = tex.FP8FwdTensors.GEMM2_INPUT +META_DO = tex.FP8BwdTensors.GRAD_INPUT2 +META_S = tex.FP8FwdTensors.GEMM3_OUTPUT +META_DP = tex.FP8BwdTensors.GRAD_INPUT3 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) @@ -98,8 +98,8 @@ _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} logging.basicConfig( - format='[%(levelname)-8s | %(name)-19s]: %(message)s', - level=log_levels[log_level if log_level in [0,1,2] else 2], + format="[%(levelname)-8s | %(name)-19s]: %(message)s", + level=log_levels[log_level if log_level in [0, 1, 2] else 2], ) _alibi_cache = { @@ -110,13 +110,13 @@ _alibi_cache = { "_alibi_bias": None, "_alibi_slopes_require_update": False, "_alibi_bias_require_update": False, - } +} __all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"] -class InferenceParams: # pylint: disable=too-few-public-methods +class InferenceParams: # pylint: disable=too-few-public-methods """ Inference parameters that are passed to the main model in order to efficienly calculate and store the context during inference. 
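# --- Editorial aside (illustration only, not part of the patch) ---------------
# The reformatted logging hunk above keeps the existing behaviour in which
# NVTE_DEBUG and NVTE_DEBUG_LEVEL multiply to select a logging level.
# A minimal, self-contained sketch of that mapping follows; the helper name
# `_nvte_log_level` is hypothetical and used only for this illustration.
import logging
import os

def _nvte_log_level() -> int:
    debug = int(os.getenv("NVTE_DEBUG", "0"))        # 0 = off, 1 = debug mode on
    level = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))  # 0 = WARNING, 1 = INFO, 2 = DEBUG
    selector = debug * level                         # debug mode off forces WARNING
    table = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
    # Anything outside {0, 1, 2} is clamped to DEBUG, as in the hunk above.
    return table[selector if selector in (0, 1, 2) else 2]

# e.g. NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 <command>  -> DEBUG-level attention logs
# -------------------------------------------------------------------------------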
@@ -161,6 +161,7 @@ class InferenceParams: # pylint: disable=too-few-public-methods new_inference_value_memory, ) + @torch.no_grad() def get_alibi( num_heads: int, @@ -217,10 +218,12 @@ def get_alibi( slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1]) if _alibi_cache["_alibi_slopes"].dim() == 2: slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1]) - bias = torch.arange( - 1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv) - bias = bias - torch.arange( - 1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(1, 1, max_seqlen_q, 1) + bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view( + 1, 1, 1, max_seqlen_kv + ) + bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view( + 1, 1, max_seqlen_q, 1 + ) bias = bias.abs().mul(-1) bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape) _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv @@ -267,8 +270,9 @@ def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch. num_nonzeros = indices.shape[0] pad_amount = bs * seqlen - num_nonzeros - indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount), - mode="constant", value=float(bs * seqlen)) + indices = F.pad( + input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen) + ) return cu_seqlens, indices @@ -281,18 +285,24 @@ def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor: """ bs = len(cu_seqlens) - 1 seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - indices = [i*max_seqlen + ii for i,j in enumerate(seqlens) for ii in range(j)] - indices = torch.Tensor(indices).unsqueeze(1).unsqueeze(1).to( - dtype=torch.int64, device="cuda") + indices = [i * max_seqlen + ii for i, j in enumerate(seqlens) for ii in range(j)] + indices = torch.Tensor(indices).unsqueeze(1).unsqueeze(1).to(dtype=torch.int64, device="cuda") num_nonzeros = indices.shape[0] pad_amount = bs * max_seqlen - num_nonzeros - indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount), - mode="constant", value=float(bs * max_seqlen)) + indices = F.pad( + input=indices, + pad=(0, 0, 0, 0, 0, pad_amount), + mode="constant", + value=float(bs * max_seqlen), + ) return indices + _cu_seqlens_cache = {} + + def _get_full_cu_seqlens( batch_size: int, max_seqlen: int, @@ -324,7 +334,8 @@ def pack_tensor( Packs the given tensor using the `indices`. """ padding_indice = torch.zeros( - 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device) + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device + ) tensor = torch.cat((tensor, padding_indice), dim=0) indices = indices.repeat(1, tensor.shape[1], tensor.shape[2]) @@ -373,9 +384,10 @@ def unpack_tensor( """ indices = indices.repeat(1, tensor.shape[1], tensor.shape[2]) unpacked = torch.zeros( - dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device) + dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device + ) unpacked.scatter_(0, indices, tensor) - unpacked = unpacked[0:-1,:,:] + unpacked = unpacked[0:-1, :, :] return unpacked @@ -415,11 +427,10 @@ class PackTensors(torch.autograd.Function): """ Autograd function to pack tensors. """ + @staticmethod def forward( - ctx, - indices: torch.Tensor, - *tensors: Tuple[torch.Tensor, ...] + ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...] 
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported." ctx.save_for_backward(indices) @@ -444,6 +455,7 @@ class UnpackTensor(torch.autograd.Function): """ Autograd function to unpack a tensor. """ + @staticmethod def forward( ctx, @@ -460,33 +472,29 @@ class UnpackTensor(torch.autograd.Function): return None, None, pack_tensor(indices, grad_output) -def flash_attn_p2p_communicate(rank, send_tensor, send_dst, - recv_tensor, recv_src, - cp_group, batch_p2p_comm): +def flash_attn_p2p_communicate( + rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm +): """Point-to-point communications of KV and dKV in Attention with context parallelism""" send_recv_ops = [] if batch_p2p_comm: if rank % 2 == 0: - send_op = torch.distributed.P2POp(torch.distributed.isend, - send_tensor, - send_dst, - cp_group) - recv_op = torch.distributed.P2POp(torch.distributed.irecv, - recv_tensor, - recv_src, - cp_group) + send_op = torch.distributed.P2POp( + torch.distributed.isend, send_tensor, send_dst, cp_group + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_tensor, recv_src, cp_group + ) send_recv_ops.append(send_op) send_recv_ops.append(recv_op) else: - recv_op = torch.distributed.P2POp(torch.distributed.irecv, - recv_tensor, - recv_src, - cp_group) - send_op = torch.distributed.P2POp(torch.distributed.isend, - send_tensor, - send_dst, - cp_group) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_tensor, recv_src, cp_group + ) + send_op = torch.distributed.P2POp( + torch.distributed.isend, send_tensor, send_dst, cp_group + ) send_recv_ops.append(recv_op) send_recv_ops.append(send_op) send_recv_reqs = torch.distributed.batch_isend_irecv(send_recv_ops) @@ -507,12 +515,11 @@ def flash_attn_p2p_communicate(rank, send_tensor, send_dst, @jit_fuser -def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, - softmax_lse, softmax_lse_per_step): +def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): """Merge partial outputs of each step in Attention with context parallelism""" softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) - out_corrected = out_per_step*softmax_lse_corrected_exp + out_corrected = out_per_step * softmax_lse_corrected_exp out.add_(out_corrected) @@ -533,10 +540,32 @@ class AttnFuncWithCP(torch.autograd.Function): """ @staticmethod - def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, dropout_p, - cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format, attn_mask_type, - attn_bias_type, attn_bias, deterministic, use_fused_attention): + def forward( + ctx, + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + dropout_p, + cp_group, + cp_global_ranks, + cp_stream, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -546,35 +575,36 @@ class AttnFuncWithCP(torch.autograd.Function): recv_src = cp_global_ranks[(rank - 1) % cp_size] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) - causal = ("causal" in attn_mask_type) - 
padding = ("padding" in attn_mask_type) + causal = "causal" in attn_mask_type + padding = "padding" in attn_mask_type qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format if causal: if qkv_format == "bshd": # [b, s, np, hn] -> [b, 2, s//2, np, hn] - q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]] + q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]] elif qkv_format == "sbhd": # [s, b, np, hn] -> [2, s//2, b, np, hn] - q, k, v = [x.view(2, x.shape[0]//2, *x.shape[1:]) for x in [q, k, v]] + q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] if attn_bias is not None: - assert (len(attn_bias.shape) == 4), ( + assert len(attn_bias.shape) == 4, ( "Only support bias shape of [b, h, sq, sk] for forward, " "and [1, h, sq, sk] for backward!" ) # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] - attn_bias_ = attn_bias.view( \ - *attn_bias.shape[:-2], \ - 2, attn_bias.shape[-2]//2, \ - 2*cp_size, attn_bias.shape[-1]//(2*cp_size) \ + attn_bias_ = attn_bias.view( + *attn_bias.shape[:-2], + 2, + attn_bias.shape[-2] // 2, + 2 * cp_size, + attn_bias.shape[-1] // (2 * cp_size), ) # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)] - attn_bias = attn_bias.view( \ - *attn_bias.shape[:-1], \ - 2*cp_size, attn_bias.shape[-1]//(2*cp_size) \ + attn_bias = attn_bias.view( + *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) ) - assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8" + assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" fa_optional_forward_kwargs = {} if _flash_attn_2_3_plus: fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1] @@ -600,268 +630,387 @@ class AttnFuncWithCP(torch.autograd.Function): p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) send_recv_reqs = [[], []] - for i in range(cp_size+1): + for i in range(cp_size + 1): if i < cp_size: - with torch.cuda.stream(flash_attn_streams[i%2]): + with torch.cuda.stream(flash_attn_streams[i % 2]): # wait until KV is received - for req in send_recv_reqs[(i+1)%2]: + for req in send_recv_reqs[(i + 1) % 2]: req.wait() - if i < (cp_size-1): - p2p_comm_buffers[i+1] = torch.empty_like(p2p_comm_buffers[i]) - send_recv_reqs[i%2] = flash_attn_p2p_communicate(rank, - p2p_comm_buffers[i], - send_dst, - p2p_comm_buffers[i+1], - recv_src, - cp_group, - batch_p2p_comm) - - kv_inputs[i%2] = p2p_comm_buffers[i] + if i < (cp_size - 1): + p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i]) + send_recv_reqs[i % 2] = flash_attn_p2p_communicate( + rank, + p2p_comm_buffers[i], + send_dst, + p2p_comm_buffers[i + 1], + recv_src, + cp_group, + batch_p2p_comm, + ) + + kv_inputs[i % 2] = p2p_comm_buffers[i] if causal: if i == 0: if use_fused_attention: if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:]) + q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - kv_inputs[i%2] = kv_inputs[i%2].view( - 2, k.shape[0], -1, *k.shape[-2:]) + kv_inputs[i % 2] = kv_inputs[i % 2].view( + 2, k.shape[0], -1, *k.shape[-2:] + ) elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i%2] = q.view(-1, *q.shape[-3:]) + q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - kv_inputs[i%2] = kv_inputs[i%2].view( - 2, -1, *k.shape[-3:]) + kv_inputs[i % 2] = 
kv_inputs[i % 2].view(2, -1, *k.shape[-3:]) elif qkv_format == "thd": - q_inputs[i%2] = q + q_inputs[i % 2] = q if attn_bias is not None: idx = (rank - i) % cp_size - attn_bias_inputs[i%2] = torch.cat( - (attn_bias[..., idx, :], \ - attn_bias[..., (2*cp_size-idx-1), :]), - dim=-1 + attn_bias_inputs[i % 2] = torch.cat( + ( + attn_bias[..., idx, :], + attn_bias[..., (2 * cp_size - idx - 1), :], + ), + dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \ - fused_attn_fwd( - is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q, - cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0], - kv_inputs[i%2][1], TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2], - seq_offsets_q=seq_offsets_q, seq_offsets_k=seq_offsets_k, - seq_offsets_v=seq_offsets_v, seq_offsets_o=seq_offsets_o, + out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( + fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_k, + cu_seqlens_q, + cu_seqlens_k, + q_inputs[i % 2], + kv_inputs[i % 2][0], + kv_inputs[i % 2][1], + TE_DType[q.dtype], + tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + seq_offsets_q=seq_offsets_q, + seq_offsets_k=seq_offsets_k, + seq_offsets_v=seq_offsets_v, + seq_offsets_o=seq_offsets_o, + ) ) if len(rest) > 0: attn_biases[i] = rest[0] else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_inputs[i%2] = q.view(-1, *q.shape[-2:]) + q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) - _, _, _, _, out_per_step[i], \ - softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward( - q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1], - cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal=True, return_softmax=False, - **fa_optional_forward_kwargs + kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) + ( + _, + _, + _, + _, + out_per_step[i], + softmax_lse_per_step[i], + _, + rng_states[i], + ) = _flash_attn_forward( + q_inputs[i % 2], + kv_inputs[i % 2][0], + kv_inputs[i % 2][1], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=True, + return_softmax=False, + **fa_optional_forward_kwargs, ) elif i <= rank: if use_fused_attention: if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:]) + q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous() + kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous() elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i%2] = q.view(-1, *q.shape[-3:]) + q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn] - kv_inputs[i%2] = kv_inputs[i%2][:, 0, ...].contiguous() + kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous() elif qkv_format == "thd": - q_inputs[i%2] = q + q_inputs[i % 2] = q # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i%2] = tex.thd_read_half_tensor( - kv_inputs[i%2], cu_seqlens_k, 
0) + kv_inputs[i % 2] = tex.thd_read_half_tensor( + kv_inputs[i % 2], cu_seqlens_k, 0 + ) if attn_bias is not None: idx = (rank - i) % cp_size - attn_bias_inputs[i%2] = attn_bias[..., idx, :].contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \ - fused_attn_fwd( - is_training, max_seqlen_q, max_seqlen_k//2, cu_seqlens_q, - cu_seqlens_k//2, q_inputs[i%2], kv_inputs[i%2][0], - kv_inputs[i%2][1], TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i%2], - seq_offsets_q=seq_offsets_q, - seq_offsets_k=None if seq_offsets_k is None \ - else seq_offsets_k//2, - seq_offsets_v=None if seq_offsets_v is None \ - else seq_offsets_v//2, - seq_offsets_o=seq_offsets_o, + attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() + out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( + fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_k // 2, + cu_seqlens_q, + cu_seqlens_k // 2, + q_inputs[i % 2], + kv_inputs[i % 2][0], + kv_inputs[i % 2][1], + TE_DType[q.dtype], + tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + seq_offsets_q=seq_offsets_q, + seq_offsets_k=( + None if seq_offsets_k is None else seq_offsets_k // 2 + ), + seq_offsets_v=( + None if seq_offsets_v is None else seq_offsets_v // 2 + ), + seq_offsets_o=seq_offsets_o, + ) ) if len(rest) > 0: attn_biases[i] = rest[0] else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_inputs[i%2] = q.view(-1, *q.shape[-2:]) + q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) if qkv_format == "thd": # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i%2] = tex.thd_read_half_tensor( - kv_inputs[i%2], cu_seqlens_k, 0) + kv_inputs[i % 2] = tex.thd_read_half_tensor( + kv_inputs[i % 2], cu_seqlens_k, 0 + ) else: # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous() + kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous() # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] - kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) + kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) if _flash_attn_2_3_plus: fa_optional_forward_kwargs["window_size"] = [-1, -1] - _, _, _, _, out_per_step[i], \ - softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward( - q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1], - cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2, - dropout_p, softmax_scale, causal=False, return_softmax=False, - **fa_optional_forward_kwargs + ( + _, + _, + _, + _, + out_per_step[i], + softmax_lse_per_step[i], + _, + rng_states[i], + ) = _flash_attn_forward( + q_inputs[i % 2], + kv_inputs[i % 2][0], + kv_inputs[i % 2][1], + cu_seqlens_q, + cu_seqlens_k // 2, + max_seqlen_q, + max_seqlen_k // 2, + dropout_p, + softmax_scale, + causal=False, + return_softmax=False, + **fa_optional_forward_kwargs, ) else: if use_fused_attention: if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_inputs[i%2] = q[:, 1, ...].contiguous() + q_inputs[i % 2] = q[:, 1, ...].contiguous() # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - kv_inputs[i%2] = kv_inputs[i%2].view( - 2, k.shape[0], -1, *k.shape[-2:]) 
+ kv_inputs[i % 2] = kv_inputs[i % 2].view( + 2, k.shape[0], -1, *k.shape[-2:] + ) elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_inputs[i%2] = q[1].contiguous() + q_inputs[i % 2] = q[1].contiguous() # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - kv_inputs[i%2] = kv_inputs[i%2].view( - 2, -1, *k.shape[-3:]) + kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:]) elif qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] - q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) + q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) if attn_bias is not None: idx = (rank - i) % cp_size - attn_bias_inputs[i%2] = torch.cat( - (attn_bias_[..., 1, :, idx, :], \ - attn_bias_[..., 1, :, (2*cp_size-idx-1), :]), - dim=-1 + attn_bias_inputs[i % 2] = torch.cat( + ( + attn_bias_[..., 1, :, idx, :], + attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], + ), + dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \ - fused_attn_fwd( - is_training, max_seqlen_q//2, max_seqlen_k, cu_seqlens_q//2, - cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0], - kv_inputs[i%2][1], TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i%2], - seq_offsets_q=None if seq_offsets_q is None \ - else seq_offsets_q//2, - seq_offsets_k=seq_offsets_k, - seq_offsets_v=seq_offsets_v, - seq_offsets_o=None if seq_offsets_o is None \ - else seq_offsets_o//2, + out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( + fused_attn_fwd( + is_training, + max_seqlen_q // 2, + max_seqlen_k, + cu_seqlens_q // 2, + cu_seqlens_k, + q_inputs[i % 2], + kv_inputs[i % 2][0], + kv_inputs[i % 2][1], + TE_DType[q.dtype], + tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + seq_offsets_q=( + None if seq_offsets_q is None else seq_offsets_q // 2 + ), + seq_offsets_k=seq_offsets_k, + seq_offsets_v=seq_offsets_v, + seq_offsets_o=( + None if seq_offsets_o is None else seq_offsets_o // 2 + ), + ) ) if len(rest) > 0: attn_biases[i] = rest[0] else: if qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] - q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) + q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) else: # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn] - q_inputs[i%2] = \ + q_inputs[i % 2] = ( q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) + ) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) + kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) if _flash_attn_2_3_plus: fa_optional_forward_kwargs["window_size"] = [-1, -1] - _, _, _, _, out_per_step[i], \ - softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward( - q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1], - cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k, - dropout_p, softmax_scale, causal=False, return_softmax=False, - **fa_optional_forward_kwargs + ( + _, + _, + _, + _, + out_per_step[i], + softmax_lse_per_step[i], + _, + rng_states[i], + ) = _flash_attn_forward( + q_inputs[i % 2], + kv_inputs[i % 2][0], + kv_inputs[i % 2][1], + cu_seqlens_q // 2, + 
cu_seqlens_k, + max_seqlen_q // 2, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=False, + return_softmax=False, + **fa_optional_forward_kwargs, ) else: if use_fused_attention: if attn_bias is not None: idx = (rank - i) % cp_size - attn_bias_inputs[i%2] = torch.cat( - (attn_bias[..., idx, :], attn_bias[..., (2*cp_size-idx-1), :]), - dim=-1 + attn_bias_inputs[i % 2] = torch.cat( + ( + attn_bias[..., idx, :], + attn_bias[..., (2 * cp_size - idx - 1), :], + ), + dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \ - fused_attn_fwd( - is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q, - cu_seqlens_k, q, kv_inputs[i%2][0], - kv_inputs[i%2][1], TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2], - seq_offsets_q=seq_offsets_q, seq_offsets_k=seq_offsets_k, - seq_offsets_v=seq_offsets_v, seq_offsets_o=seq_offsets_o, + out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( + fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_k, + cu_seqlens_q, + cu_seqlens_k, + q, + kv_inputs[i % 2][0], + kv_inputs[i % 2][1], + TE_DType[q.dtype], + tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + seq_offsets_q=seq_offsets_q, + seq_offsets_k=seq_offsets_k, + seq_offsets_v=seq_offsets_v, + seq_offsets_o=seq_offsets_o, + ) ) if len(rest) > 0: attn_biases[i] = rest[0] else: # [b, sq, np, hn] -> [b*sq, np, hn] - q_inputs[i%2] = q.view(-1, *q.shape[-2:]) + q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) # [2, b, sk, np, hn] -> [2, b*sk, np, hn] - kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) - _, _, _, _, out_per_step[i], \ - softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward( - q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1], - cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal=False, return_softmax=False, - **fa_optional_forward_kwargs + kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) + ( + _, + _, + _, + _, + out_per_step[i], + softmax_lse_per_step[i], + _, + rng_states[i], + ) = _flash_attn_forward( + q_inputs[i % 2], + kv_inputs[i % 2][0], + kv_inputs[i % 2][1], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=False, + return_softmax=False, + **fa_optional_forward_kwargs, ) if i > 0: # wait until fwd restuls correction of last step is done if i > 1: - flash_attn_streams[(i-1)%2].wait_event(fwd_results_correction_done) + flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done) if use_fused_attention: # [b, np, sq, 1] -> [b, np, sq] - softmax_lse_per_step[i-1].squeeze_(-1) + softmax_lse_per_step[i - 1].squeeze_(-1) - with torch.cuda.stream(flash_attn_streams[(i-1)%2]): + with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if i == 1: out = torch.empty_like(q).zero_() softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) if causal and qkv_format != "thd": # [b, np, sq] -> [b, np, 2, sq//2] softmax_lse_ = softmax_lse.view( - *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2 + *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 ) - elif (i-1) <= rank or not causal: - 
flash_attn_fwd_softmax_lse_correction(softmax_lse, - softmax_lse_per_step[i-1]) + elif (i - 1) <= rank or not causal: + flash_attn_fwd_softmax_lse_correction( + softmax_lse, softmax_lse_per_step[i - 1] + ) else: if qkv_format == "thd": - tex.thd_second_half_lse_correction(softmax_lse, - softmax_lse_per_step[i-1], - cu_seqlens_q, - q.size(0)) + tex.thd_second_half_lse_correction( + softmax_lse, softmax_lse_per_step[i - 1], cu_seqlens_q, q.size(0) + ) else: - flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :], - softmax_lse_per_step[i-1]) + flash_attn_fwd_softmax_lse_correction( + softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1] + ) if i < cp_size: - flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done) + flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done) torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) @@ -878,34 +1027,42 @@ class AttnFuncWithCP(torch.autograd.Function): if i <= rank or not causal: if qkv_format in ["bshd", "sbhd"]: - flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape), - out_per_step[i], - seq_dim, - softmax_lse, - softmax_lse_per_step[i]) + flash_attn_fwd_out_correction( + out.view(*out_per_step[i].shape), + out_per_step[i], + seq_dim, + softmax_lse, + softmax_lse_per_step[i], + ) elif qkv_format == "thd": - tex.thd_out_correction(out, - out_per_step[i], - softmax_lse, - softmax_lse_per_step[i], - cu_seqlens_q, - False) + tex.thd_out_correction( + out, + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + cu_seqlens_q, + False, + ) else: assert False, f"{qkv_format} is an unsupported qkv_format!" else: if qkv_format in ["bshd", "sbhd"]: - flash_attn_fwd_out_correction(out_, - out_per_step[i], - seq_dim, - softmax_lse_[..., 1, :], - softmax_lse_per_step[i]) + flash_attn_fwd_out_correction( + out_, + out_per_step[i], + seq_dim, + softmax_lse_[..., 1, :], + softmax_lse_per_step[i], + ) elif qkv_format == "thd": - tex.thd_out_correction(out, - out_per_step[i], - softmax_lse, - softmax_lse_per_step[i], - cu_seqlens_q, - True) + tex.thd_out_correction( + out, + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + cu_seqlens_q, + True, + ) else: assert False, f"{qkv_format} is an unsupported qkv_format!" 
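# --- Editorial aside (illustration only, not part of the patch) ---------------
# The correction hunks above merge the per-step partial outputs of
# context-parallel attention using each step's softmax log-sum-exp (LSE):
#   out = sum_i exp(lse_i - lse) * out_i,  with  lse = log(sum_i exp(lse_i)).
# Below is a simplified, non-incremental sketch of the same math for the "bshd"
# layout; the function name `merge_attention_steps` is hypothetical and this is
# not the library's implementation, which accumulates the corrections step by step.
import torch

def merge_attention_steps(outs, lses):
    """outs: list of [b, s, h, d] partial outputs; lses: list of [b, h, s] LSEs."""
    # Combined LSE over all steps: log(sum_i exp(lse_i))
    lse = torch.logsumexp(torch.stack(lses, dim=0), dim=0)            # [b, h, s]
    merged = torch.zeros_like(outs[0])
    for out_i, lse_i in zip(outs, lses):
        # Per-step weight exp(lse_i - lse), broadcast over the head dimension
        weight = torch.exp(lse_i - lse).movedim(2, 1).unsqueeze(-1)   # [b, s, h, 1]
        merged += out_i * weight
    return merged
# -------------------------------------------------------------------------------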
@@ -919,9 +1076,18 @@ class AttnFuncWithCP(torch.autograd.Function): out = out.view(-1, *out.shape[-2:]) ctx.save_for_backward( - q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - *rng_states, *attn_biases + q, + kv, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + *rng_states, + *attn_biases, ) ctx.cp_group = cp_group ctx.cp_global_ranks = cp_global_ranks @@ -942,28 +1108,26 @@ class AttnFuncWithCP(torch.autograd.Function): (q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6] (seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o) = ctx.saved_tensors[6:10] cp_size = get_distributed_world_size(ctx.cp_group) - rng_states = ctx.saved_tensors[10:10+cp_size] - attn_biases = ctx.saved_tensors[10+cp_size:10+cp_size*2] + rng_states = ctx.saved_tensors[10 : 10 + cp_size] + attn_biases = ctx.saved_tensors[10 + cp_size : 10 + cp_size * 2] rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) - causal = ("causal" in ctx.attn_mask_type) - padding = ("padding" in ctx.attn_mask_type) + causal = "causal" in ctx.attn_mask_type + padding = "padding" in ctx.attn_mask_type qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format if attn_biases[0] is not None: # [b, np, sq, 2*cp, sk//(2*cp)] attn_dbias = torch.zeros( - *ctx.attn_bias_shape, - dtype=attn_biases[0].dtype, - device=attn_biases[0].device + *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device ) # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] attn_dbias_ = attn_dbias.view( - *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3]//2, *attn_dbias.shape[-2:] + *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:] ) else: attn_dbias = None @@ -973,8 +1137,9 @@ class AttnFuncWithCP(torch.autograd.Function): softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0)) else: # [b, np, sq] -> [b, np, 2, sq//2] - softmax_lse_ = \ - softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2) + softmax_lse_ = softmax_lse.view( + *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 + ) softmax_lse_ = softmax_lse_[..., 1, :].contiguous() if ctx.use_fused_attention: # [b, np, sq//2] -> [b, np, sq//2, 1] @@ -988,8 +1153,10 @@ class AttnFuncWithCP(torch.autograd.Function): # Flash Attn outputs dq = torch.empty_like(q) - p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), \ - torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device)] + p2p_comm_buffers = [ + torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + ] p2p_comm_buffers[0][0].copy_(kv) send_recv_reqs = [] @@ -1004,27 +1171,23 @@ class AttnFuncWithCP(torch.autograd.Function): for req in send_recv_reqs: req.wait() - send_tensor = p2p_comm_buffers[i%2] - recv_tensor = p2p_comm_buffers[(i+1)%2] + send_tensor = p2p_comm_buffers[i % 2] + recv_tensor = p2p_comm_buffers[(i + 1) % 2] if i == 0: send_tensor = send_tensor[0] recv_tensor = recv_tensor[0] - if i == (cp_size-1): + if i == (cp_size - 1): send_tensor = send_tensor[1] recv_tensor = recv_tensor[1] - send_recv_reqs = flash_attn_p2p_communicate(rank, - send_tensor, - send_dst, - recv_tensor, - 
recv_src, - ctx.cp_group, - batch_p2p_comm) + send_recv_reqs = flash_attn_p2p_communicate( + rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm + ) - kv = p2p_comm_buffers[i%2][0] + kv = p2p_comm_buffers[i % 2][0] # In reversed order of fwd if causal: - if i == (cp_size-1): + if i == (cp_size - 1): if ctx.use_fused_attention: if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] @@ -1044,16 +1207,27 @@ class AttnFuncWithCP(torch.autograd.Function): dout_ = dout.view(-1, *dout.shape[-3:]) elif ctx.qkv_format == "thd": q_, kv_, out_, dout_ = q, kv, out, dout - aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]] + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size-i-1]] + aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( - ctx.max_seqlen_q, ctx.max_seqlen_k, - cu_seqlens_q, cu_seqlens_k, - q_, kv_[0], kv_[1], out_, dout_, - TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + cu_seqlens_q, + cu_seqlens_k, + q_, + kv_[0], + kv_[1], + out_, + dout_, + TE_DType[q.dtype], + TE_DType[kv.dtype], + aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, @@ -1073,14 +1247,26 @@ class AttnFuncWithCP(torch.autograd.Function): if _flash_attn_2_3_plus: fa_optional_backward_kwargs["window_size"] = [-1, 0] _flash_attn_backward( - dout_, q_, kv_[0], kv_[1], out_, softmax_lse, - dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k, - ctx.max_seqlen_q, ctx.max_seqlen_k, - ctx.dropout_p, ctx.softmax_scale, True, - rng_state=rng_states[cp_size-i-1], - **fa_optional_backward_kwargs + dout_, + q_, + kv_[0], + kv_[1], + out_, + softmax_lse, + dq_, + dkv_[0], + dkv_[1], + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + True, + rng_state=rng_states[cp_size - i - 1], + **fa_optional_backward_kwargs, ) - elif i >= (cp_size-rank-1): + elif i >= (cp_size - rank - 1): if ctx.use_fused_attention: if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] @@ -1102,17 +1288,27 @@ class AttnFuncWithCP(torch.autograd.Function): q_, out_, dout_ = q, out, dout # [2, t, np, hn] -> [2, t/2, np, hn] kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0) - aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]] + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size-i-1]] + aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( - ctx.max_seqlen_q, ctx.max_seqlen_k//2, - cu_seqlens_q, cu_seqlens_k//2, - q_, kv_[0], kv_[1], out_, dout_, - TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors, + ctx.max_seqlen_q, + ctx.max_seqlen_k // 2, + cu_seqlens_q, + cu_seqlens_k // 2, + q_, + kv_[0], + kv_[1], + out_, + dout_, + TE_DType[q.dtype], + TE_DType[kv.dtype], + aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - seq_offsets_q, None if seq_offsets_k is None else seq_offsets_k//2, - None if seq_offsets_v is None else seq_offsets_v//2, seq_offsets_o, + seq_offsets_q, + None if seq_offsets_k is None else seq_offsets_k // 2, + None if seq_offsets_v is None else seq_offsets_v // 2, + seq_offsets_o, 
attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, @@ -1136,12 +1332,24 @@ class AttnFuncWithCP(torch.autograd.Function): if _flash_attn_2_3_plus: fa_optional_backward_kwargs["window_size"] = [-1, -1] _flash_attn_backward( - dout_, q_, kv_[0], kv_[1], out_, softmax_lse, - dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2, - ctx.max_seqlen_q, ctx.max_seqlen_k//2, - ctx.dropout_p, ctx.softmax_scale, False, - rng_state=rng_states[cp_size-i-1], - **fa_optional_backward_kwargs + dout_, + q_, + kv_[0], + kv_[1], + out_, + softmax_lse, + dq_, + dkv_[0], + dkv_[1], + cu_seqlens_q, + cu_seqlens_k // 2, + ctx.max_seqlen_q, + ctx.max_seqlen_k // 2, + ctx.dropout_p, + ctx.softmax_scale, + False, + rng_state=rng_states[cp_size - i - 1], + **fa_optional_backward_kwargs, ) else: if ctx.use_fused_attention: @@ -1167,17 +1375,27 @@ class AttnFuncWithCP(torch.autograd.Function): out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1) dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1) kv_ = kv - aux_ctx_tensors = [softmax_lse_, rng_states[cp_size-i-1]] + aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size-i-1]] + aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( - ctx.max_seqlen_q//2, ctx.max_seqlen_k, - cu_seqlens_q//2, cu_seqlens_k, - q_, kv_[0], kv_[1], out_, dout_, - TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors, + ctx.max_seqlen_q // 2, + ctx.max_seqlen_k, + cu_seqlens_q // 2, + cu_seqlens_k, + q_, + kv_[0], + kv_[1], + out_, + dout_, + TE_DType[q.dtype], + TE_DType[kv.dtype], + aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - None if seq_offsets_q is None else seq_offsets_q//2, seq_offsets_k, - seq_offsets_v, None if seq_offsets_o is None else seq_offsets_o//2, + None if seq_offsets_q is None else seq_offsets_q // 2, + seq_offsets_k, + seq_offsets_v, + None if seq_offsets_o is None else seq_offsets_o // 2, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, @@ -1205,25 +1423,48 @@ class AttnFuncWithCP(torch.autograd.Function): if _flash_attn_2_3_plus: fa_optional_backward_kwargs["window_size"] = [-1, -1] _flash_attn_backward( - dout_, q_, kv_[0], kv_[1], out_, softmax_lse_, - dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k, - ctx.max_seqlen_q//2, ctx.max_seqlen_k, - ctx.dropout_p, ctx.softmax_scale, False, - rng_state=rng_states[cp_size-i-1], - **fa_optional_backward_kwargs + dout_, + q_, + kv_[0], + kv_[1], + out_, + softmax_lse_, + dq_, + dkv_[0], + dkv_[1], + cu_seqlens_q // 2, + cu_seqlens_k, + ctx.max_seqlen_q // 2, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + False, + rng_state=rng_states[cp_size - i - 1], + **fa_optional_backward_kwargs, ) else: if ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]] + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size-i-1]] + aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( - ctx.max_seqlen_q, ctx.max_seqlen_k, - cu_seqlens_q, cu_seqlens_k, - q, kv[0], kv[1], out, dout, - TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + cu_seqlens_q, + cu_seqlens_k, + q, + kv[0], + kv[1], + out, + dout, + TE_DType[q.dtype], + TE_DType[kv.dtype], + aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - seq_offsets_q, seq_offsets_k, 
seq_offsets_v, seq_offsets_o, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, @@ -1243,14 +1484,26 @@ class AttnFuncWithCP(torch.autograd.Function): if _flash_attn_2_3_plus: fa_optional_backward_kwargs["window_size"] = [-1, -1] _flash_attn_backward( - dout_, q_, kv_[0], kv_[1], out_, softmax_lse, - dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k, - ctx.max_seqlen_q, ctx.max_seqlen_k, - ctx.dropout_p, ctx.softmax_scale, False, - **fa_optional_backward_kwargs + dout_, + q_, + kv_[0], + kv_[1], + out_, + softmax_lse, + dq_, + dkv_[0], + dkv_[1], + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + False, + **fa_optional_backward_kwargs, ) - if i >= (cp_size-rank-1) or not causal: + if i >= (cp_size - rank - 1) or not causal: # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal # [b*sq, np, hn] -> [b, sq, np, hn] if not causal dq_ = dq_.view(*dq.shape) @@ -1263,10 +1516,10 @@ class AttnFuncWithCP(torch.autograd.Function): dq_ = dq_.view(-1, *dq.shape[-3:]) if causal: - if i > (cp_size-rank-1): + if i > (cp_size - rank - 1): dq.add_(dq_) - elif i == (cp_size-rank-1): - if rank == (cp_size-1): + elif i == (cp_size - rank - 1): + if rank == (cp_size - 1): dq.copy_(dq_) else: if ctx.qkv_format == "bshd": @@ -1298,29 +1551,29 @@ class AttnFuncWithCP(torch.autograd.Function): dq.add_(dq_) if attn_dbias is not None: - idx = (rank+i+1)%cp_size + idx = (rank + i + 1) % cp_size if i == (cp_size - 1) or not causal: # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)] - dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1]//2) + dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) attn_dbias[..., idx, :].copy_(dbias_[..., 0, :]) - attn_dbias[..., (2*cp_size-idx-1), :].copy_(dbias_[..., 1, :]) - elif i >= (cp_size-rank-1): + attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) + elif i >= (cp_size - rank - 1): # [b, np, sq, sk//(2*cp)] attn_dbias[..., idx, :].copy_(dbias_) else: # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)] - dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1]//2) + dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :]) - attn_dbias_[..., 1, :, (2*cp_size-idx-1), :].copy_(dbias_[..., 1, :]) + attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) # wait until dKV is received for req in send_recv_reqs: req.wait() - dkv = p2p_comm_buffers[(i+1)%2][1] + dkv = p2p_comm_buffers[(i + 1) % 2][1] if ctx.use_fused_attention: dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0) - if causal and i >= (cp_size-rank-1) and i != (cp_size-1): + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): if ctx.qkv_format == "bshd": # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn] dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:]) @@ -1333,7 +1586,7 @@ class AttnFuncWithCP(torch.autograd.Function): dkv_ = dkv_.view(*dkv.shape) if causal: - if i == (cp_size-1): + if i == (cp_size - 1): if rank == 0: if ctx.qkv_format == "bshd": dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) @@ -1345,8 +1598,8 @@ class AttnFuncWithCP(torch.autograd.Function): tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy") else: dkv.add_(dkv_) - elif i >= (cp_size-rank-1): - if i == 0 and rank == (cp_size-1): + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): if ctx.qkv_format == "bshd": dkv[:, :, 0, 
...].copy_(dkv_) elif ctx.qkv_format == "sbhd": @@ -1386,35 +1639,103 @@ class AttnFuncWithCP(torch.autograd.Function): # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) - return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None, attn_dbias, None, None + return ( + None, + dq, + dkv[0], + dkv[1], + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + attn_dbias, + None, + None, + ) def attn_forward_func_with_cp( - is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, dropout_p, - cp_group, cp_global_ranks, cp_stream, softmax_scale=None, qkv_format="bshd", - attn_mask_type="causal", attn_bias_type="no_bias", attn_bias=None, deterministic=False, - use_fused_attention=False + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + dropout_p, + cp_group, + cp_global_ranks, + cp_stream, + softmax_scale=None, + qkv_format="bshd", + attn_mask_type="causal", + attn_bias_type="no_bias", + attn_bias=None, + deterministic=False, + use_fused_attention=False, ) -> torch.Tensor: """Attention implementation with context parallelism""" - assert(qkv_format in ["bshd", "sbhd", "thd"] - ), f"QKV format of {qkv_format} is not supported with context parallelism!" - assert(qkv_format != "sbhd" or use_fused_attention - ), "FlashAttention does not support sbhd format!" - assert (qkv_format != 'thd' or \ - not use_fused_attention or \ - attn_mask_type in ["padding", "padding_causal"] - ), f"Context parallelism is not supported for {attn_mask_type} mask type and " \ - f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!" - assert (attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type) - ), """Attention bias is only supported with FusedAttention and "causal" """ \ - """or "no_mask" mask types!""" + assert qkv_format in [ + "bshd", + "sbhd", + "thd", + ], f"QKV format of {qkv_format} is not supported with context parallelism!" + assert ( + qkv_format != "sbhd" or use_fused_attention + ), "FlashAttention does not support sbhd format!" + assert ( + qkv_format != "thd" + or not use_fused_attention + or attn_mask_type in ["padding", "padding_causal"] + ), ( + f"Context parallelism is not supported for {attn_mask_type} mask type and " + f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!" 
+ ) + assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), ( + """Attention bias is only supported with FusedAttention and "causal" """ + """or "no_mask" mask types!""" + ) out = AttnFuncWithCP.apply( - is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, dropout_p, - cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format, attn_mask_type, - attn_bias_type, attn_bias, deterministic, use_fused_attention + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + dropout_p, + cp_group, + cp_global_ranks, + cp_stream, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, ) return out @@ -1423,6 +1744,7 @@ class RotaryPositionEmbedding(torch.nn.Module): """ Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864. """ + def __init__( self, dim: int, @@ -1454,7 +1776,7 @@ class RotaryPositionEmbedding(torch.nn.Module): / dim ) ) - self.register_buffer('inv_freq', inv_freq) + self.register_buffer("inv_freq", inv_freq) self.pretrained_max_position_embeddings = pretrained_max_position_embeddings def forward(self, max_seq_len: int, offset: int = 0): @@ -1473,17 +1795,21 @@ class RotaryPositionEmbedding(torch.nn.Module): + offset ) - if (self.pretrained_max_position_embeddings is not None - and self.seq_len_interpolation_factor is not None): - if (max_seq_len > - self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor): + if ( + self.pretrained_max_position_embeddings is not None + and self.seq_len_interpolation_factor is not None + ): + if ( + max_seq_len + > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor + ): # dynamic linear scaling (length > position we have learned) seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings) else: # fixed linear scaling seq *= 1 / self.seq_len_interpolation_factor - freqs = torch.einsum('i , j -> i j', seq, self.inv_freq) + freqs = torch.einsum("i , j -> i j", seq, self.inv_freq) # first part even vector components, second part odd vector components, # 2 * dim in dimension size emb = torch.cat((freqs, freqs), dim=-1) @@ -1513,9 +1839,7 @@ class FusedRoPEFunc(torch.autograd.Function): if tensor_format == "sbhd": output = tex.fused_rope_forward(t, freqs, False) elif tensor_format == "bshd": - output = tex.fused_rope_forward( - t.transpose(0, 1), freqs, True - ).transpose(0, 1) + output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1) elif tensor_format == "thd": output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs) else: @@ -1526,9 +1850,7 @@ class FusedRoPEFunc(torch.autograd.Function): return output @staticmethod - def backward( - ctx, grad_output: torch.Tensor - ) -> Tuple[Union[torch.Tensor, None], ...]: + def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: freqs, cu_seqlens = ctx.saved_tensors if ctx.tensor_format == "sbhd": grad_input = tex.fused_rope_backward(grad_output, freqs, False) @@ -1596,9 +1918,9 @@ def apply_rotary_pos_emb( # Only apply the rotary embeddings up to the sequence length of the running # input. - assert cur_seq_len <= max_seq_len, ( - f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" 
- ) + assert ( + cur_seq_len <= max_seq_len + ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" freqs = freqs[:cur_seq_len] if tensor_format == "bshd": freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] @@ -1620,32 +1942,37 @@ class _SplitAlongDim(torch.autograd.Function): """""" @staticmethod - def forward(ctx, - mixed_x_layer: torch.Tensor, - split_dim: int, - split_size_or_sections: Union[int, List[int], Tuple[int]], + def forward( + ctx, + mixed_x_layer: torch.Tensor, + split_dim: int, + split_size_or_sections: Union[int, List[int], Tuple[int]], ) -> Tuple[torch.Tensor, ...]: ctx.split_dim = split_dim ctx.split_size_or_sections = split_size_or_sections if isinstance(mixed_x_layer, Float8Tensor): - return tuple(Float8Tensor.make_like( - mixed_x_layer, - data=x, - ) for x in torch.split( + return tuple( + Float8Tensor.make_like( + mixed_x_layer, + data=x, + ) + for x in torch.split( mixed_x_layer._data, split_size_or_sections=split_size_or_sections, - dim=split_dim)) - return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim) + dim=split_dim, + ) + ) + return torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim) @staticmethod - def backward(ctx, - *grad_outputs): + def backward(ctx, *grad_outputs): assert len(grad_outputs) > 0, "No gradients received for backprop!" if isinstance(ctx.split_size_or_sections, (list, tuple)): split_sizes = ctx.split_size_or_sections - assert (len(grad_outputs) == len(split_sizes) - ), "Unequal number of gradients vs split sections for backprop!" + assert len(grad_outputs) == len( + split_sizes + ), "Unequal number of gradients vs split sections for backprop!" if isinstance(ctx.split_size_or_sections, int): split_sizes = [ctx.split_size_or_sections] * len(grad_outputs) dims = len(grad_outputs[0].shape) @@ -1659,29 +1986,37 @@ class _SplitAlongDim(torch.autograd.Function): for i, tensor in enumerate(grad_outputs): shape_i = shape shape_i[split_dim] = split_sizes[i] - offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim+1:]) - if (tensor.stride() != strides or - list(tensor.shape) != shape_i or - tensor._data.untyped_storage().data_ptr() != data_ptr or - tensor.storage_offset() != offset_size): + offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :]) + if ( + tensor.stride() != strides + or list(tensor.shape) != shape_i + or tensor._data.untyped_storage().data_ptr() != data_ptr + or tensor.storage_offset() != offset_size + ): noop_ok = False break if noop_ok: - ret = torch.Tensor().to(device=grad_outputs[0].device, - dtype=grad_outputs[0]._data.dtype) + ret = torch.Tensor().to( + device=grad_outputs[0].device, dtype=grad_outputs[0]._data.dtype + ) new_shape = list(shape) new_shape[split_dim] = sum(split_sizes) - ret.set_(grad_outputs[0]._data.untyped_storage(), - grad_outputs[0]._data.storage_offset(), - new_shape, - strides + ret.set_( + grad_outputs[0]._data.untyped_storage(), + grad_outputs[0]._data.storage_offset(), + new_shape, + strides, ) return Float8Tensor.make_like(grad_outputs[0], data=ret), None, None grad_outputs_data = [x._data for x in grad_outputs] - return Float8Tensor.make_like( - grad_outputs[0], - data=torch.cat(grad_outputs_data, dim = split_dim)), None, None + return ( + Float8Tensor.make_like( + grad_outputs[0], data=torch.cat(grad_outputs_data, dim=split_dim) + ), + None, + None, + ) noop_ok = True strides = grad_outputs[0].stride() data_ptr = grad_outputs[0].untyped_storage().data_ptr() @@ -1689,26 +2024,28 @@ class 
_SplitAlongDim(torch.autograd.Function): for i, tensor in enumerate(grad_outputs): shape_i = shape shape_i[split_dim] = split_sizes[i] - offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim+1:]) - if (tensor.stride() != strides or - list(tensor.shape) != shape_i or - tensor.untyped_storage().data_ptr() != data_ptr or - tensor.storage_offset() != offset_size): + offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :]) + if ( + tensor.stride() != strides + or list(tensor.shape) != shape_i + or tensor.untyped_storage().data_ptr() != data_ptr + or tensor.storage_offset() != offset_size + ): noop_ok = False break if noop_ok: - ret = torch.Tensor().to(device=grad_outputs[0].device, - dtype=grad_outputs[0].dtype) + ret = torch.Tensor().to(device=grad_outputs[0].device, dtype=grad_outputs[0].dtype) new_shape = list(shape) new_shape[split_dim] = sum(split_sizes) - ret.set_(grad_outputs[0].untyped_storage(), - grad_outputs[0].storage_offset(), - new_shape, - strides + ret.set_( + grad_outputs[0].untyped_storage(), + grad_outputs[0].storage_offset(), + new_shape, + strides, ) return ret, None, None - return torch.cat(grad_outputs, dim = split_dim), None, None + return torch.cat(grad_outputs, dim=split_dim), None, None class UnfusedDotProductAttention(torch.nn.Module): @@ -1738,7 +2075,8 @@ class UnfusedDotProductAttention(torch.nn.Module): # An FP16 training trick required for certain GPT-like models. self.apply_qk_layer_scaling = ( - bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None) + bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None + ) def forward( self, @@ -1746,8 +2084,8 @@ class UnfusedDotProductAttention(torch.nn.Module): key_layer: torch.Tensor, value_layer: torch.Tensor, qkv_layout: str = "sbh3d", - cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument - cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, core_attention_bias_type: str = "no_bias", @@ -1756,13 +2094,15 @@ class UnfusedDotProductAttention(torch.nn.Module): ) -> torch.Tensor: """Unfused attention fprop""" - assert (qkv_layout in QKVLayouts - ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!" - qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) - if qkv_format == 'bshd': + assert ( + qkv_layout in QKVLayouts + ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!" 
+ qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "bshd": # convert to sbhd and use sbhd implementation for now - query_layer, key_layer, value_layer = [x.transpose(0, 1) - for x in [query_layer, key_layer, value_layer]] + query_layer, key_layer, value_layer = [ + x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] + ] batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 @@ -1776,17 +2116,18 @@ class UnfusedDotProductAttention(torch.nn.Module): ) if key_layer.shape[2] != query_layer.shape[2]: - assert (query_layer.shape[2]%key_layer.shape[2]==0 - ),"The number of attention heads must be divisible by the number of GQA groups!" + assert ( + query_layer.shape[2] % key_layer.shape[2] == 0 + ), "The number of attention heads must be divisible by the number of GQA groups!" key_layer = key_layer.repeat_interleave( - int(query_layer.shape[2]/key_layer.shape[2]), dim = 2) + int(query_layer.shape[2] / key_layer.shape[2]), dim=2 + ) value_layer = value_layer.repeat_interleave( - int(query_layer.shape[2]/value_layer.shape[2]), dim = 2) + int(query_layer.shape[2] / value_layer.shape[2]), dim=2 + ) # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.reshape( - output_size[2], output_size[0] * output_size[1], -1 - ) + query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) @@ -1824,9 +2165,10 @@ class UnfusedDotProductAttention(torch.nn.Module): query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] ) - matmul_result = (matmul_result.view( - output_size[0], output_size[1], output_size[2], output_size[3]) - + core_attention_bias).view(-1, output_size[2], output_size[3]) + matmul_result = ( + matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3]) + + core_attention_bias + ).view(-1, output_size[2], output_size[3]) matmul_result *= scale elif core_attention_bias_type in ["post_scale_bias", "alibi"]: @@ -1834,7 +2176,8 @@ class UnfusedDotProductAttention(torch.nn.Module): assert core_attention_bias is not None, "core_attention_bias should not be None!" 
if core_attention_bias_type == "alibi": _, core_attention_bias = get_alibi( - output_size[1], output_size[2], output_size[3], alibi_slopes=alibi_slopes) + output_size[1], output_size[2], output_size[3], alibi_slopes=alibi_slopes + ) matmul_result = torch.baddbmm( matmul_result, query_layer.transpose(0, 1), # [b * np, sq, hn] @@ -1842,10 +2185,16 @@ class UnfusedDotProductAttention(torch.nn.Module): beta=0.0, alpha=scale, ) - matmul_result = (matmul_result.view( - output_size[0], output_size[1], output_size[2], output_size[3]) - + core_attention_bias).view(-1, output_size[2], output_size[3]).to( - dtype=query_layer.dtype) + matmul_result = ( + ( + matmul_result.view( + output_size[0], output_size[1], output_size[2], output_size[3] + ) + + core_attention_bias + ) + .view(-1, output_size[2], output_size[3]) + .to(dtype=query_layer.dtype) + ) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) @@ -1853,7 +2202,8 @@ class UnfusedDotProductAttention(torch.nn.Module): # attention scores and attention mask [b, np, sq, sk] softmax_scale = self.layer_number if apply_qk_layer_scaling else None attention_probs = self.scale_mask_softmax( - attention_scores, attention_mask, attn_mask_type, softmax_scale) + attention_scores, attention_mask, attn_mask_type, softmax_scale + ) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. @@ -1870,14 +2220,10 @@ class UnfusedDotProductAttention(torch.nn.Module): ) # change view [sk, b * np, hn] - value_layer = value_layer.reshape( - value_layer.size(0), output_size[0] * output_size[1], -1 - ) + value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1) # change view [b * np, sq, sk] - attention_probs = attention_probs.view( - output_size[0] * output_size[1], output_size[2], -1 - ) + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) @@ -1885,14 +2231,14 @@ class UnfusedDotProductAttention(torch.nn.Module): # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) - if qkv_format == 'sbhd': + if qkv_format == "sbhd": # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] context_layer = context_layer.view(seqlen, batch_size, -1) - if qkv_format == 'bshd': + if qkv_format == "bshd": # [b, np, sq, hn] --> [b, sq, np, hn] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() @@ -1904,14 +2250,14 @@ class UnfusedDotProductAttention(torch.nn.Module): class _PrepareQKVForFA(torch.autograd.Function): """This class converts QKV from interleaved (s, b, ...) layout - to separate contiguous q, k, v tensors in (b, s, ...) layout.""" + to separate contiguous q, k, v tensors in (b, s, ...) layout.""" @staticmethod def forward( _ctx: torch.autograd.function.FunctionCtx, # unused query_layer: torch.Tensor, key_layer: torch.Tensor, - value_layer: torch.Tensor + value_layer: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # All inputs received are non-contiguous tensors. 
# The `query_layer` tensor is used to access the @@ -1928,7 +2274,7 @@ class _PrepareQKVForFA(torch.autograd.Function): _ctx: torch.autograd.function.FunctionCtx, # unused dq: torch.Tensor, dk: torch.Tensor, - dv: torch.Tensor + dv: torch.Tensor, ) -> Tuple[Union[torch.Tensor, None], ...]: dqkv = tex.fa_prepare_bwd(dq, dk, dv) dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3) @@ -1936,11 +2282,11 @@ class _PrepareQKVForFA(torch.autograd.Function): def _get_qkv_layout( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - qkv_format: str = 'sbhd', - ) -> str: + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qkv_format: str = "sbhd", +) -> str: """Get qkv layout. Parameters @@ -1992,60 +2338,72 @@ def _get_qkv_layout( check_shapes_kv = all(shape == x.shape for x in [k, v]) last_dim_size = q.shape[-1] - check_last_dim_offsets_qkv = all(i * last_dim_size == x.storage_offset() - for i, x in enumerate([q, k, v])) + check_last_dim_offsets_qkv = all( + i * last_dim_size == x.storage_offset() for i, x in enumerate([q, k, v]) + ) last_dim_size = k.shape[-1] - check_last_dim_offsets_kv = all(i * last_dim_size == x.storage_offset() - for i, x in enumerate([k, v])) + check_last_dim_offsets_kv = all( + i * last_dim_size == x.storage_offset() for i, x in enumerate([k, v]) + ) last_two_dims_size = q.shape[-1] * q.shape[-2] - check_last_two_dims_offsets_qkv = all(i * last_two_dims_size == x.storage_offset() - for i, x in enumerate([q, k, v])) + check_last_two_dims_offsets_qkv = all( + i * last_two_dims_size == x.storage_offset() for i, x in enumerate([q, k, v]) + ) last_two_dims_size = k.shape[-1] * k.shape[-2] - check_last_two_dims_offsets_kv = all(i * last_two_dims_size == x.storage_offset() - for i, x in enumerate([k, v])) + check_last_two_dims_offsets_kv = all( + i * last_two_dims_size == x.storage_offset() for i, x in enumerate([k, v]) + ) - if (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv + if ( + check_ptrs_qkv + and check_strides_qkv + and check_shapes_qkv and check_last_two_dims_offsets_qkv - and not check_last_dim_offsets_qkv): + and not check_last_dim_offsets_qkv + ): # sb3hd, bs3hd, t3hd - qkv_layout = qkv_format[:-2] + '3' + qkv_format[-2:] - elif (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv - and check_last_dim_offsets_qkv): + qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:] + elif ( + check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv + ): # sbh3d, bsh3d, th3d - qkv_layout = qkv_format[:-1] + '3' + qkv_format[-1:] - elif (check_ptrs_kv and check_strides_kv and check_shapes_kv + qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:] + elif ( + check_ptrs_kv + and check_strides_kv + and check_shapes_kv and check_last_two_dims_offsets_kv - and not check_last_dim_offsets_kv): + and not check_last_dim_offsets_kv + ): # sbhd_sb2hd, bshd_bs2hd, thd_t2hd - qkv_layout = qkv_format + '_' + qkv_format[:-2] + '2' + qkv_format[-2:] - elif (check_ptrs_kv and check_strides_kv and check_shapes_kv - and check_last_dim_offsets_kv): + qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] + elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_last_dim_offsets_kv: # sbhd_sbh2d, bshd_bsh2d, thd_th2d - qkv_layout = qkv_format + '_' + qkv_format[:-1] + '2' + qkv_format[-1:] + qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:] elif check_strides_kv and check_shapes_kv: # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd - qkv_layout = '_'.join(list([qkv_format])*3) + qkv_layout = 
"_".join(list([qkv_format]) * 3) else: - qkv_layout = 'not_supported' + qkv_layout = "not_supported" return qkv_layout qkv_layout = run_iteratively(q, k, v) - if qkv_layout == 'not_supported': + if qkv_layout == "not_supported": # force q,k,v to be contiguous and run get_layout again q, k, v = [x.contiguous() for x in [q, k, v]] qkv_layout = run_iteratively(q, k, v) - if qkv_layout == 'not_supported': + if qkv_layout == "not_supported": raise Exception("The provided qkv memory layout is not supported!") return qkv_layout, q, k, v def check_set_window_size( - attn_mask_type: str, - window_size: Tuple[int, int] = None, - ): + attn_mask_type: str, + window_size: Tuple[int, int] = None, +): """Check if sliding window size is compliant with mask type and if not, assert or set it to the appropriate size """ @@ -2118,36 +2476,40 @@ class FlashAttention(torch.nn.Module): query_layer.dtype in [torch.float16, torch.bfloat16] and key_layer.dtype in [torch.float16, torch.bfloat16] and value_layer.dtype in [torch.float16, torch.bfloat16] - ), "FlashAttention currently only supports FP16 and BF16." + ), "FlashAttention currently only supports FP16 and BF16." assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda - ), "FlashAttention currently only supports CUDA tensors." + ), "FlashAttention currently only supports CUDA tensors." assert ( qkv_layout in QKVLayouts - ), f"FlashAttention does not support qkv_layout = {qkv_layout}!" + ), f"FlashAttention does not support qkv_layout = {qkv_layout}!" context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1) - qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == 'sbhd': + if qkv_format == "sbhd": # For now just 128, will make it more general in the future - if (query_layer.shape[-1] == 128 and - query_layer.shape[0] * query_layer.shape[1] >= 512 and - qkv_layout == "sbh3d"): - query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer, - key_layer, - value_layer) + if ( + query_layer.shape[-1] == 128 + and query_layer.shape[0] * query_layer.shape[1] >= 512 + and qkv_layout == "sbh3d" + ): + query_layer, key_layer, value_layer = _PrepareQKVForFA.apply( + query_layer, key_layer, value_layer + ) else: - query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous() - for x in (query_layer, key_layer, value_layer)] - elif qkv_format in ['bshd', 'thd']: - query_layer, key_layer, value_layer = [x.contiguous() - for x in (query_layer, key_layer, value_layer)] + query_layer, key_layer, value_layer = [ + x.transpose(0, 1).contiguous() for x in (query_layer, key_layer, value_layer) + ] + elif qkv_format in ["bshd", "thd"]: + query_layer, key_layer, value_layer = [ + x.contiguous() for x in (query_layer, key_layer, value_layer) + ] batch_size = query_layer.shape[0] - if qkv_format in ['sbhd', 'bshd']: + if qkv_format in ["sbhd", "bshd"]: max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1] if not context_parallel: # [b * s, h, d] @@ -2156,7 +2518,7 @@ class FlashAttention(torch.nn.Module): for x in [query_layer, key_layer, value_layer] ] - if 'padding' in attn_mask_type: + if "padding" in attn_mask_type: assert not context_parallel, "Padding mask not supported with context parallelism!" 
if self.attention_type == "self": @@ -2164,8 +2526,9 @@ class FlashAttention(torch.nn.Module): max_seqlen_q == max_seqlen_kv ), "Maximum sequence length for Q and KV should be the same." if cu_seqlens_q is None: - assert (attention_mask is not None - ), "Please provide attention_mask for padding!" + assert ( + attention_mask is not None + ), "Please provide attention_mask for padding!" cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask) else: indices_q = get_indices(max_seqlen_q, cu_seqlens_q) @@ -2175,19 +2538,16 @@ class FlashAttention(torch.nn.Module): ) else: if cu_seqlens_q is None or cu_seqlens_kv is None: - assert (attention_mask is not None - ), "Please provide attention_mask for padding!" - cu_seqlens_q, indices_q = get_cu_seqlens_and_indices( - attention_mask[0]) - cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices( - attention_mask[1]) + assert ( + attention_mask is not None + ), "Please provide attention_mask for padding!" + cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0]) + cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1]) else: indices_q = get_indices(max_seqlen_q, cu_seqlens_q) indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv) query_layer = PackTensors.apply(indices_q, query_layer) - key_layer, value_layer = PackTensors.apply( - indices_kv, key_layer, value_layer - ) + key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer) else: # Cumulative sequence lengths for unpadded data if cu_seqlens_q is None: @@ -2202,9 +2562,10 @@ class FlashAttention(torch.nn.Module): max_seqlen_kv, key_layer.device, ) - elif qkv_format == 'thd': - assert (cu_seqlens_q is not None and cu_seqlens_kv is not None - ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" + elif qkv_format == "thd": + assert ( + cu_seqlens_q is not None and cu_seqlens_kv is not None + ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" if max_seqlen_q is None: seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] max_seqlen_q = seqlens_q.max().item() @@ -2213,27 +2574,40 @@ class FlashAttention(torch.nn.Module): max_seqlen_kv = seqlens_kv.max().item() if context_parallel: - assert ( - window_size in ((-1, -1), (-1, 0)) - ), "Sliding window attention is not supported with context parallelism." + assert window_size in ( + (-1, -1), + (-1, 0), + ), "Sliding window attention is not supported with context parallelism." assert ( alibi_slopes is None ), "Alibi slope bias addition is not supported with context parallelism." 
with self.attention_dropout_ctx(): output = attn_forward_func_with_cp( - self.training, query_layer, key_layer, value_layer, - cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - None, None, None, None, + self.training, + query_layer, + key_layer, + value_layer, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + None, + None, + None, + None, self.attention_dropout if self.training else 0.0, - cp_group, cp_global_ranks, cp_stream, + cp_group, + cp_global_ranks, + cp_stream, softmax_scale=self.softmax_scale, - qkv_format="bshd" if qkv_format=="sbhd" else qkv_format, + qkv_format="bshd" if qkv_format == "sbhd" else qkv_format, attn_mask_type=attn_mask_type, - deterministic=self.deterministic + deterministic=self.deterministic, ) else: from .cpu_offload import CPUOffloadEnabled + if CPUOffloadEnabled: tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv] for tensor in tensor_list: @@ -2249,102 +2623,142 @@ class FlashAttention(torch.nn.Module): if _flash_attn_2_4_1_plus: fa_optional_forward_kwargs["deterministic"] = self.deterministic output = flash_attn_forward_func( - query_layer, key_layer, value_layer, - cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, + query_layer, + key_layer, + value_layer, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, self.attention_dropout if self.training else 0.0, - softmax_scale=self.softmax_scale, causal="causal" in attn_mask_type, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, **fa_optional_forward_kwargs, ) - if qkv_format in ['sbhd', 'bshd'] and 'padding' in attn_mask_type: + if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) - if qkv_format == 'sbhd': + if qkv_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous() - elif qkv_format == 'bshd': + elif qkv_format == "bshd": # (bs)hd -> bs(hd) output = output.view(batch_size, max_seqlen_q, -1).contiguous() - elif qkv_format == 'thd': + elif qkv_format == "thd": # thd -> t(hd) output = output.view(output.shape[0], -1).contiguous() return output + def _combine_tensors( - tensors: List[torch.Tensor], - dim: int, - ) -> torch.Tensor: + tensors: List[torch.Tensor], + dim: int, +) -> torch.Tensor: """Combine tensors along a particular dimension""" num_tensors = len(tensors) new_shape = list(tensors[0].shape) new_shape.insert(dim, num_tensors) new_stride = list(tensors[0].stride()) - new_stride.insert(dim, int(new_stride[dim-1]/num_tensors)) + new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors)) if isinstance(tensors[0], Float8Tensor): - combined_tensor = torch.Tensor().to( - device=tensors[0].device, dtype=tensors[0]._data.dtype) + combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0]._data.dtype) combined_tensor.set_( tensors[0]._data.untyped_storage(), tensors[0]._data.storage_offset(), - new_shape, new_stride) - combined_tensor = Float8Tensor.make_like( - tensors[0], data=combined_tensor) + new_shape, + new_stride, + ) + combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor) else: - combined_tensor = torch.Tensor().to( - device=tensors[0].device, dtype=tensors[0].dtype) + combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0].dtype) combined_tensor.set_( - tensors[0].untyped_storage(), - tensors[0].storage_offset(), - new_shape, new_stride) + tensors[0].untyped_storage(), 
tensors[0].storage_offset(), new_shape, new_stride + ) return combined_tensor + class FusedAttnFunc_qkvpacked(torch.autograd.Function): """Function for FusedAttention with packed QKV input""" @staticmethod - def forward(ctx, is_training, max_seqlen, cu_seqlens, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - qkv, qkv_dtype, attn_bias, attn_scale, - dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen, fused_attention_backend, use_FAv2_bwd, - fp8, fp8_meta): + def forward( + ctx, + is_training, + max_seqlen, + cu_seqlens, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + qkv, + qkv_dtype, + attn_bias, + attn_scale, + dropout_p, + fast_zero_fill, + qkv_layout, + attn_bias_type, + attn_mask_type, + rng_gen, + fused_attention_backend, + use_FAv2_bwd, + fp8, + fp8_meta, + ): logger = logging.getLogger("FusedAttnFunc_qkvpacked") if fp8: logger.debug("Running forward in FP8") if fp8_meta["recipe"].fp8_mha: - assert (isinstance(qkv, Float8Tensor)), "qkv must be Float8Tensors for FP8 MHA." + assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA." fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split('_')) - assert (qkv_group == 1 - ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, \ - but found {qkv_layout}." + qkv_group = len(qkv_layout.split("_")) + assert qkv_group == 1, ( + "qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found" + f" {qkv_layout}." + ) if fp8_meta["recipe"].fp8_mha: qkv_fp8 = qkv._data else: qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = cast_to_fp8(qkv_c, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(qkv.shape) + qkv_fp8 = cast_to_fp8( + qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward + ).view(qkv.shape) out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked( - is_training, max_seqlen, cu_seqlens, - qkv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + is_training, + max_seqlen, + cu_seqlens, + qkv_fp8, + fp8_dtype_forward, + fused_attention_backend, + attn_bias, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale[META_O], fp8_meta["scaling_fwd"].amax_history[0][META_S], fp8_meta["scaling_fwd"].amax_history[0][META_O], - attn_scale, dropout_p, fast_zero_fill, qkv_layout, - attn_bias_type, attn_mask_type, rng_gen) + attn_scale, + dropout_p, + fast_zero_fill, + qkv_layout, + attn_bias_type, + attn_mask_type, + rng_gen, + ) if fp8_meta["recipe"].fp8_mha: - out_ret = Float8Tensor(data=out_fp8, + out_ret = Float8Tensor( + data=out_fp8, fp8_meta=fp8_meta, fp8_meta_forward=True, fp8_meta_index=META_O, @@ -2354,38 +2768,77 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): else: out_ret = cast_from_fp8( out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], META_O, - fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) out_save = out_ret if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): qkv_c = qkv.view(-1, 
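# Illustration only; a standalone sketch (not part of this patch) of the view trick in
# _combine_tensors above: q/k/v that are interleaved slices of one "...h3d" buffer can
# be recombined into a single packed tensor purely by inserting a dimension into the
# shape and stride metadata of the shared storage. Shapes below are illustrative.
import torch

s, b, h, d = 4, 2, 8, 64
qkv = torch.randn(s, b, h, 3, d)
q, k, v = [qkv[..., i, :] for i in range(3)]      # interleaved views of qkv

base = q
new_shape = list(base.shape)
new_shape.insert(3, 3)                            # -> [s, b, h, 3, d]
new_stride = list(base.stride())
new_stride.insert(3, new_stride[2] // 3)          # stride of the inserted packing dim

combined = torch.Tensor().to(device=base.device, dtype=base.dtype)
combined.set_(base.untyped_storage(), base.storage_offset(), new_shape, new_stride)
assert torch.equal(combined, qkv)                 # same data, no copy performed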
qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv = cast_from_fp8(qkv_c._data, + qkv = cast_from_fp8( + qkv_c._data, fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape) + META_QKV, + fp8_dtype_forward, + TE_DType[qkv.dtype], + ).view(qkv.shape) out_save = cast_from_fp8( out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], META_O, - fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) - fp8_tensors = (qkv_fp8, out_fp8, + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) + fp8_tensors = ( + qkv_fp8, + out_fp8, fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone()) + fp8_meta["scaling_fwd"].scale_inv.clone(), + ) else: - logger.debug("Running forward in %s",qkv.dtype) + logger.debug("Running forward in %s", qkv.dtype) out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked( - is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, - fused_attention_backend, attn_bias, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - None, None, None, None, None, None, - attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen) + is_training, + max_seqlen, + cu_seqlens, + qkv, + qkv_dtype, + fused_attention_backend, + attn_bias, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + None, + None, + None, + None, + None, + None, + attn_scale, + dropout_p, + fast_zero_fill, + qkv_layout, + attn_bias_type, + attn_mask_type, + rng_gen, + ) fp8_tensors = (None, None, None, None) out_save = out_ret ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) - ctx.save_for_backward(*qkvo_tensors, cu_seqlens, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - *fp8_tensors, *aux_ctx_tensors) + ctx.save_for_backward( + *qkvo_tensors, + cu_seqlens, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + *fp8_tensors, + *aux_ctx_tensors, + ) ctx.fp8_meta = fp8_meta ctx.max_seqlen = max_seqlen ctx.qkv_dtype = qkv_dtype @@ -2395,8 +2848,9 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type - ctx.fused_attention_backend = \ + ctx.fused_attention_backend = ( fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] + ) ctx.use_FAv2_bwd = use_FAv2_bwd return out_ret @@ -2405,118 +2859,256 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): def backward(ctx, d_out): logger = logging.getLogger("FusedAttnFunc_qkvpacked") if ctx.fp8_meta["recipe"].fp8_mha: - assert (isinstance(d_out, Float8Tensor) - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." + assert isinstance( + d_out, Float8Tensor + ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." 
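# Illustration only; a conceptual sketch (not the Transformer Engine API and not part
# of this patch) of the scale / scale_inv / amax bookkeeping the FP8 paths above wrap
# around the fused kernel: inputs are scaled into the representable FP8 range on the
# way in and rescaled on the way out, while the observed amax feeds the next scale
# update. The actual 8-bit rounding of cast_to_fp8 is not simulated here.
import torch

FP8_E4M3_MAX = 448.0

def quantize(x, scale):
    # cast-to-fp8 analogue: scale, clamp to the FP8 dynamic range, record amax
    amax = x.abs().max()
    x_q = torch.clamp(x * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return x_q, amax

def dequantize(x_q, scale_inv):
    # cast-from-fp8 analogue: undo the scaling in higher precision
    return x_q * scale_inv

x = torch.randn(16, 64)
scale = FP8_E4M3_MAX / x.abs().max()      # keep values inside the FP8 range
x_q, amax = quantize(x, scale)
x_back = dequantize(x_q, 1.0 / scale)
print(torch.allclose(x, x_back))          # True (no 8-bit rounding applied above)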
d_out_f8tensor = d_out d_out = d_out._data d_out = d_out.contiguous() - (qkv, out, cu_seqlens, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - qkv_fp8, out_fp8, - fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors + ( + qkv, + out, + cu_seqlens, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + qkv_fp8, + out_fp8, + fwd_scales, + fwd_scale_invs, + *aux_ctx_tensors, + ) = ctx.saved_tensors if not aux_ctx_tensors[0].is_contiguous(): aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: softmax_lse, rng_state = aux_ctx_tensors dqkv = torch.empty_like(qkv) maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - d_out, q, k, v, out = [maybe_contiguous(x) - for x in (d_out, qkv[:,0], qkv[:,1], qkv[:,2], out)] + d_out, q, k, v, out = [ + maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out) + ] flash_attn_cuda_bwd( - d_out, q, k, v, out, softmax_lse, dqkv[:,0], dqkv[:,1], dqkv[:,2], - cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen, - ctx.dropout_p, ctx.attn_scale, False, - "causal" in ctx.attn_mask_type, None, rng_state + d_out, + q, + k, + v, + out, + softmax_lse, + dqkv[:, 0], + dqkv[:, 1], + dqkv[:, 2], + cu_seqlens, + cu_seqlens, + ctx.max_seqlen, + ctx.max_seqlen, + ctx.dropout_p, + ctx.attn_scale, + False, + "causal" in ctx.attn_mask_type, + None, + rng_state, ) - dqkv = dqkv[..., :d_out.shape[-1]] + dqkv = dqkv[..., : d_out.shape[-1]] else: with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"): if ctx.fp8: logger.debug("Running backward in FP8") - fp8_dtype_forward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False) + ctx.fp8_meta["recipe"], fprop_tensor=False + ) if ctx.fp8_meta["recipe"].fp8_mha: d_out_fp8 = d_out - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: d_out_fp8 = cast_to_fp8( d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ).view(d_out.shape) + ctx.fp8_meta["scaling_bwd"], + META_DO, + fp8_dtype_backward, + ).view(d_out.shape) dqkv_fp8, *rest = fused_attn_bwd_qkvpacked( - ctx.max_seqlen, cu_seqlens, - qkv_fp8, out_fp8, d_out_fp8, - fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors, + ctx.max_seqlen, + cu_seqlens, + qkv_fp8, + out_fp8, + d_out_fp8, + fp8_dtype_forward, + fp8_dtype_backward, + aux_ctx_tensors, ctx.fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp - ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + fwd_scale_invs[META_QKV], # d_scale_qkv, + fwd_scale_invs[META_S], # d_scale_s, + fwd_scale_invs[META_O], # 
d_scale_o, + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp + fwd_scales[META_S], # q_scale_s + ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp + ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv + ctx.attn_scale, + ctx.dropout_p, + ctx.fast_zero_fill, + ctx.qkv_layout, + ctx.attn_bias_type, + ctx.attn_mask_type, + ) if ctx.fp8_meta["recipe"].fp8_mha: - dqkv = Float8Tensor(data=dqkv_fp8, + dqkv = Float8Tensor( + data=dqkv_fp8, fp8_meta=ctx.fp8_meta, fp8_meta_forward=False, fp8_meta_index=META_DQKV, fp8_dtype=fp8_dtype_backward, dtype=d_out_f8tensor.dtype, - ) + ) else: - dqkv_c_fp8 = dqkv_fp8.view(-1, - dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]) - dqkv = cast_from_fp8(dqkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape) + dqkv_c_fp8 = dqkv_fp8.view( + -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1] + ) + dqkv = cast_from_fp8( + dqkv_c_fp8, + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + ctx.qkv_dtype, + ).view(dqkv_fp8.shape) else: - logger.debug("Running backward in %s",qkv.dtype) + logger.debug("Running backward in %s", qkv.dtype) if d_out.dtype == torch.uint8: d_out = d_out_f8tensor.from_float8(qkv.dtype) dqkv, *rest = fused_attn_bwd_qkvpacked( - ctx.max_seqlen, cu_seqlens, qkv, out, d_out, - ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors, + ctx.max_seqlen, + cu_seqlens, + qkv, + out, + d_out, + ctx.qkv_dtype, + ctx.qkv_dtype, + aux_ctx_tensors, ctx.fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - None, None, None, None, None, None, None, None, None, None, - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ctx.attn_scale, + ctx.dropout_p, + ctx.fast_zero_fill, + ctx.qkv_layout, + ctx.attn_bias_type, + ctx.attn_mask_type, + ) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, None, None, None, None, dqkv, None, None, None, - None, None, None, None, None, None, - None, None, None, None, None, None) + return ( + None, + None, + None, + None, + None, + None, + None, + dqkv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) # else, return (dqkv, dbias) - return (None, None, None, None, None, None, None, dqkv, None, rest[0], None, - None, None, None, None, None, None, - None, None, None, None, None, None) + return ( + None, + None, + None, + None, + None, + None, + None, + dqkv, + None, + rest[0], + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) class FusedAttnFunc_kvpacked(torch.autograd.Function): """Function for FusedAttention with packed KV input""" @staticmethod - def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, - qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, - use_FAv2_bwd, fp8, fp8_meta): + def forward( + ctx, + 
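# Illustration only; a minimal standalone example (not part of this patch) of the
# autograd rule behind the long "return (None, ..., dqkv, ...)" tuples above:
# backward() must return exactly one gradient per forward() argument, with None for
# arguments (flags, metadata, sequence offsets) that do not require gradients.
import torch

class ScaleByConst(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, label):          # three inputs after ctx
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # one entry per forward input: a gradient for x, None for scale and label
        return grad_out * ctx.scale, None, None

x = torch.randn(4, requires_grad=True)
ScaleByConst.apply(x, 2.0, "demo").sum().backward()
print(x.grad)                                   # every element equals 2.0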
is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + q, + kv, + qkv_dtype, + attn_bias, + attn_scale, + dropout_p, + fast_zero_fill, + qkv_layout, + attn_bias_type, + attn_mask_type, + rng_gen, + fused_attention_backend, + use_FAv2_bwd, + fp8, + fp8_meta, + ): logger = logging.getLogger("FusedAttnFunc_kvpacked") if fp8: logger.debug("Running forward in FP8") if fp8_meta["recipe"].fp8_mha: - assert (isinstance(q, Float8Tensor) - and isinstance(kv, Float8Tensor)), "q/kv must be Float8Tensors for FP8 MHA." + assert isinstance(q, Float8Tensor) and isinstance( + kv, Float8Tensor + ), "q/kv must be Float8Tensors for FP8 MHA." fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -2524,31 +3116,50 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): q_fp8, kv_fp8 = q._data, kv._data else: # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split('_')) - assert (qkv_group == 2 - ), f"qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, \ - but found {qkv_layout}." - q_fp8 = cast_to_fp8(q, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(q.shape) + qkv_group = len(qkv_layout.split("_")) + assert qkv_group == 2, ( + "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " + f" but found {qkv_layout}." + ) + q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view( + q.shape + ) kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = cast_to_fp8(kv_c, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(kv.shape) + kv_fp8 = cast_to_fp8( + kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward + ).view(kv.shape) out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked( - is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q_fp8, kv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q_fp8, + kv_fp8, + fp8_dtype_forward, + fused_attention_backend, + attn_bias, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale[META_O], fp8_meta["scaling_fwd"].amax_history[0][META_S], fp8_meta["scaling_fwd"].amax_history[0][META_O], - attn_scale, dropout_p, fast_zero_fill, qkv_layout, - attn_bias_type, attn_mask_type, rng_gen) + attn_scale, + dropout_p, + fast_zero_fill, + qkv_layout, + attn_bias_type, + attn_mask_type, + rng_gen, + ) if fp8_meta["recipe"].fp8_mha: - out_ret = Float8Tensor(data=out_fp8, + out_ret = Float8Tensor( + data=out_fp8, fp8_meta=fp8_meta, fp8_meta_forward=True, fp8_meta_index=META_O, @@ -2558,41 +3169,85 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): else: out_ret = cast_from_fp8( out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], META_O, - fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) out_save = out_ret if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = cast_from_fp8(q._data, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, 
TE_DType[q.dtype]).view(q.shape) + q = cast_from_fp8( + q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype] + ).view(q.shape) kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv = cast_from_fp8(kv_c._data, + kv = cast_from_fp8( + kv_c._data, fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape) + META_QKV, + fp8_dtype_forward, + TE_DType[kv.dtype], + ).view(kv.shape) out_save = cast_from_fp8( out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], META_O, - fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) - fp8_tensors = (q_fp8, kv_fp8, out_fp8, + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) + fp8_tensors = ( + q_fp8, + kv_fp8, + out_fp8, fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone()) + fp8_meta["scaling_fwd"].scale_inv.clone(), + ) else: - logger.debug("Running forward in %s",q.dtype) + logger.debug("Running forward in %s", q.dtype) out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked( - is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, kv, qkv_dtype, fused_attention_backend, attn_bias, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - None, None, None, None, None, None, - attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen) + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q, + kv, + qkv_dtype, + fused_attention_backend, + attn_bias, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + None, + None, + None, + None, + None, + None, + attn_scale, + dropout_p, + fast_zero_fill, + qkv_layout, + attn_bias_type, + attn_mask_type, + rng_gen, + ) out_save = out_ret fp8_tensors = (None, None, None, None, None) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) - ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - *fp8_tensors, *aux_ctx_tensors) + ctx.save_for_backward( + *qkvo_tensors, + cu_seqlens_q, + cu_seqlens_kv, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + *fp8_tensors, + *aux_ctx_tensors, + ) ctx.fp8_meta = fp8_meta ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv @@ -2603,8 +3258,9 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type - ctx.fused_attention_backend = \ + ctx.fused_attention_backend = ( fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] + ) ctx.use_FAv2_bwd = use_FAv2_bwd return out_ret @@ -2613,16 +3269,30 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): def backward(ctx, d_out): logger = logging.getLogger("FusedAttnFunc_kvpacked") if ctx.fp8_meta["recipe"].fp8_mha: - assert (isinstance(d_out, Float8Tensor) - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." + assert isinstance( + d_out, Float8Tensor + ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." 
d_out_f8tensor = d_out d_out = d_out._data d_out = d_out.contiguous() - (q, kv, out, cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - q_fp8, kv_fp8, out_fp8, - fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors + ( + q, + kv, + out, + cu_seqlens_q, + cu_seqlens_kv, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + q_fp8, + kv_fp8, + out_fp8, + fwd_scales, + fwd_scale_invs, + *aux_ctx_tensors, + ) = ctx.saved_tensors if not aux_ctx_tensors[0].is_contiguous(): aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: @@ -2630,167 +3300,329 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): dq = torch.empty_like(q) dkv = torch.empty_like(kv) maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - d_out, q, k, v, out = [maybe_contiguous(x) - for x in (d_out, q, kv[:,0], kv[:,1], out)] + d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)] flash_attn_cuda_bwd( - d_out, q, k, v, out, softmax_lse, dq, dkv[:,0], dkv[:,1], - cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv, - ctx.dropout_p, ctx.attn_scale, False, - "causal" in ctx.attn_mask_type, None, rng_state + d_out, + q, + k, + v, + out, + softmax_lse, + dq, + dkv[:, 0], + dkv[:, 1], + cu_seqlens_q, + cu_seqlens_kv, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ctx.dropout_p, + ctx.attn_scale, + False, + "causal" in ctx.attn_mask_type, + None, + rng_state, ) - dq = dq[..., :d_out.shape[-1]] - dkv = dkv[..., :d_out.shape[-1]] + dq = dq[..., : d_out.shape[-1]] + dkv = dkv[..., : d_out.shape[-1]] else: with torch.cuda.nvtx.range("_FusedAttn_kvpacked"): if ctx.fp8: logger.debug("Running backward in FP8") - fp8_dtype_forward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False) + ctx.fp8_meta["recipe"], fprop_tensor=False + ) if ctx.fp8_meta["recipe"].fp8_mha: d_out_fp8 = d_out - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: d_out_fp8 = cast_to_fp8( d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ).view(d_out.shape) + ctx.fp8_meta["scaling_bwd"], + META_DO, + fp8_dtype_backward, + ).view(d_out.shape) dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked( - ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q_fp8, kv_fp8, out_fp8, d_out_fp8, - fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q_fp8, + kv_fp8, + out_fp8, + d_out_fp8, + fp8_dtype_forward, + fp8_dtype_backward, + aux_ctx_tensors, ctx.fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp - ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv - ctx.attn_scale, ctx.dropout_p, 
ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + fwd_scale_invs[META_QKV], # d_scale_qkv, + fwd_scale_invs[META_S], # d_scale_s, + fwd_scale_invs[META_O], # d_scale_o, + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp + fwd_scales[META_S], # q_scale_s + ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp + ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv + ctx.attn_scale, + ctx.dropout_p, + ctx.fast_zero_fill, + ctx.qkv_layout, + ctx.attn_bias_type, + ctx.attn_mask_type, + ) if ctx.fp8_meta["recipe"].fp8_mha: - dq = Float8Tensor(data=dq_fp8, + dq = Float8Tensor( + data=dq_fp8, fp8_meta=ctx.fp8_meta, fp8_meta_forward=False, fp8_meta_index=META_DQKV, fp8_dtype=fp8_dtype_backward, dtype=d_out_f8tensor.dtype, - ) - dkv = Float8Tensor(data=dkv_fp8, + ) + dkv = Float8Tensor( + data=dkv_fp8, fp8_meta=ctx.fp8_meta, fp8_meta_forward=False, fp8_meta_index=META_DQKV, fp8_dtype=fp8_dtype_backward, dtype=d_out_f8tensor.dtype, - ) + ) else: dq = cast_from_fp8( dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape) - dkv_c_fp8 = dkv_fp8.view(-1, - dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]) - dkv = cast_from_fp8(dkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape) + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + ctx.qkv_dtype, + ).view(dq_fp8.shape) + dkv_c_fp8 = dkv_fp8.view( + -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] + ) + dkv = cast_from_fp8( + dkv_c_fp8, + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + ctx.qkv_dtype, + ).view(dkv_fp8.shape) else: - logger.debug("Running backward in %s",q.dtype) + logger.debug("Running backward in %s", q.dtype) if d_out.dtype == torch.uint8: d_out = d_out_f8tensor.from_float8(q.dtype) dq, dkv, *rest = fused_attn_bwd_kvpacked( - ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, kv, out, d_out, - ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q, + kv, + out, + d_out, + ctx.qkv_dtype, + ctx.qkv_dtype, + aux_ctx_tensors, ctx.fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - None, None, None, None, None, None, None, None, None, None, - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ctx.attn_scale, + ctx.dropout_p, + ctx.fast_zero_fill, + ctx.qkv_layout, + ctx.attn_bias_type, + ctx.attn_mask_type, + ) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, None, None, None, None, None, None, dq, dkv, None, None, None, - None, None, None, None, None, None, - None, None, None, None, None, None) + return ( + None, + None, + None, + None, + None, + None, + None, + None, + None, + dq, + dkv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) # else, return (dqkv, 
dbias) - return (None, None, None, None, None, None, None, None, None, dq, dkv, None, rest[0], None, - None, None, None, None, None, None, - None, None, None, None, None, None) + return ( + None, + None, + None, + None, + None, + None, + None, + None, + None, + dq, + dkv, + None, + rest[0], + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + class FusedAttnFunc(torch.autograd.Function): """Function for FusedAttention with separate Q, K, V tensors""" @staticmethod - def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, - qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, - use_FAv2_bwd, fp8, fp8_meta): + def forward( + ctx, + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + q, + k, + v, + qkv_dtype, + attn_bias, + attn_scale, + dropout_p, + fast_zero_fill, + qkv_layout, + attn_bias_type, + attn_mask_type, + rng_gen, + fused_attention_backend, + use_FAv2_bwd, + fp8, + fp8_meta, + ): logger = logging.getLogger("FusedAttnFunc") if fp8: logger.debug("Running forward in FP8") fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if fp8_meta["recipe"].fp8_mha: - assert (isinstance(q, Float8Tensor) + assert ( + isinstance(q, Float8Tensor) and isinstance(k, Float8Tensor) - and isinstance(v, Float8Tensor)), "q/k/v must be Float8Tensors for FP8 MHA." + and isinstance(v, Float8Tensor) + ), "q/k/v must be Float8Tensors for FP8 MHA." fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data else: # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split('_')) + qkv_group = len(qkv_layout.split("_")) if qkv_group == 1: - dim = qkv_layout.find('3') - qkv = _combine_tensors([q,k,v], dim) + dim = qkv_layout.find("3") + qkv = _combine_tensors([q, k, v], dim) qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = cast_to_fp8(qkv_c, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(qkv.shape) - q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1,1,1]) + qkv_fp8 = cast_to_fp8( + qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward + ).view(qkv.shape) + q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1]) q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]] if qkv_group == 2: - q_fp8 = cast_to_fp8(q, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(q.shape) - dim = qkv_layout.split('_')[1].find('2') - kv = _combine_tensors([k,v], dim) + q_fp8 = cast_to_fp8( + q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward + ).view(q.shape) + dim = qkv_layout.split("_")[1].find("2") + kv = _combine_tensors([k, v], dim) kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = cast_to_fp8(kv_c, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(kv.shape) - k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1,1]) + kv_fp8 = cast_to_fp8( + kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward + ).view(kv.shape) + k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1]) k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]] if qkv_group == 3: - q_fp8 = cast_to_fp8(q, - fp8_meta["scaling_fwd"], - META_QKV, 
fp8_dtype_forward).view(q.shape) - k_fp8 = cast_to_fp8(k, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(k.shape) - v_fp8 = cast_to_fp8(v, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(v.shape) + q_fp8 = cast_to_fp8( + q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward + ).view(q.shape) + k_fp8 = cast_to_fp8( + k, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward + ).view(k.shape) + v_fp8 = cast_to_fp8( + v, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward + ).view(v.shape) out_fp8, aux_ctx_tensors = fused_attn_fwd( - is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q_fp8, + k_fp8, + v_fp8, + fp8_dtype_forward, + fused_attention_backend, + attn_bias, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale[META_O], fp8_meta["scaling_fwd"].amax_history[0][META_S], fp8_meta["scaling_fwd"].amax_history[0][META_O], - attn_scale, dropout_p, fast_zero_fill, qkv_layout, - attn_bias_type, attn_mask_type, rng_gen) + attn_scale, + dropout_p, + fast_zero_fill, + qkv_layout, + attn_bias_type, + attn_mask_type, + rng_gen, + ) if fp8_meta["recipe"].fp8_mha: - out_ret = Float8Tensor(data=out_fp8, + out_ret = Float8Tensor( + data=out_fp8, fp8_meta=fp8_meta, fp8_meta_forward=True, fp8_meta_index=META_O, @@ -2800,77 +3632,144 @@ class FusedAttnFunc(torch.autograd.Function): else: out_ret = cast_from_fp8( out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], META_O, - fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) out_save = out_ret if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split('_')) + qkv_group = len(qkv_layout.split("_")) if qkv_group == 1: - dim = qkv_layout.find('3') - qkv = _combine_tensors([q,k,v], dim) + dim = qkv_layout.find("3") + qkv = _combine_tensors([q, k, v], dim) qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = cast_from_fp8(qkv_c._data, + qkv_no_fp8 = cast_from_fp8( + qkv_c._data, fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape) - q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1,1,1]) + META_QKV, + fp8_dtype_forward, + TE_DType[qkv.dtype], + ).view(qkv.shape) + q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1]) q, k, v = [x.squeeze(dim) for x in [q, k, v]] if qkv_group == 2: - q = cast_from_fp8(q._data, + q = cast_from_fp8( + q._data, fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape) - dim = qkv_layout.split('_')[1].find('2') - kv = _combine_tensors([k,v], dim) + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + dim = qkv_layout.split("_")[1].find("2") + kv = _combine_tensors([k, v], dim) kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = cast_from_fp8(kv_c._data, + kv_no_fp8 = cast_from_fp8( + kv_c._data, fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape) - k, v = 
_SplitAlongDim.apply(kv_no_fp8, dim, [1,1]) + META_QKV, + fp8_dtype_forward, + TE_DType[kv.dtype], + ).view(kv.shape) + k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1]) k, v = [x.squeeze(dim) for x in [k, v]] if qkv_group == 3: - q = cast_from_fp8(q._data, + q = cast_from_fp8( + q._data, fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape) - k = cast_from_fp8(k._data, + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + k = cast_from_fp8( + k._data, fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[k.dtype]).view(k.shape) - v = cast_from_fp8(v._data, + META_QKV, + fp8_dtype_forward, + TE_DType[k.dtype], + ).view(k.shape) + v = cast_from_fp8( + v._data, fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[v.dtype]).view(v.shape) + META_QKV, + fp8_dtype_forward, + TE_DType[v.dtype], + ).view(v.shape) out_save = cast_from_fp8( out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], META_O, - fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) - - fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8, + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) + + fp8_tensors = ( + q_fp8, + k_fp8, + v_fp8, + out_fp8, fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone()) + fp8_meta["scaling_fwd"].scale_inv.clone(), + ) else: - logger.debug("Running forward in %s",q.dtype) + logger.debug("Running forward in %s", q.dtype) out_ret, aux_ctx_tensors = fused_attn_fwd( - is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, k, v, qkv_dtype, fused_attention_backend, attn_bias, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - None, None, None, None, None, None, - attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen) + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + qkv_dtype, + fused_attention_backend, + attn_bias, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + None, + None, + None, + None, + None, + None, + attn_scale, + dropout_p, + fast_zero_fill, + qkv_layout, + attn_bias_type, + attn_mask_type, + rng_gen, + ) out_save = out_ret fp8_tensors = (None, None, None, None, None, None) from .cpu_offload import CPUOffloadEnabled + if CPUOffloadEnabled: tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv] - qkv_layout = 'sbhd_sbhd_sbhd' + qkv_layout = "sbhd_sbhd_sbhd" for tensor in tensor_list: if tensor is not None: tensor.activation_offloading = True ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) - ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - *fp8_tensors, *aux_ctx_tensors) + ctx.save_for_backward( + *qkvo_tensors, + cu_seqlens_q, + cu_seqlens_kv, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + *fp8_tensors, + *aux_ctx_tensors, + ) ctx.fp8_meta = fp8_meta ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv @@ -2881,8 +3780,9 @@ class FusedAttnFunc(torch.autograd.Function): ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type - ctx.fused_attention_backend = \ + ctx.fused_attention_backend = ( fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] + ) ctx.use_FAv2_bwd = use_FAv2_bwd return out_ret @@ -2891,16 
+3791,32 @@ class FusedAttnFunc(torch.autograd.Function): def backward(ctx, d_out): logger = logging.getLogger("FusedAttnFunc") if ctx.fp8_meta["recipe"].fp8_mha: - assert (isinstance(d_out, Float8Tensor) - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." + assert isinstance( + d_out, Float8Tensor + ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." d_out_f8tensor = d_out d_out = d_out._data d_out = d_out.contiguous() - (q, k, v, out, cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - q_fp8, k_fp8, v_fp8, out_fp8, - fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors + ( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_kv, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + q_fp8, + k_fp8, + v_fp8, + out_fp8, + fwd_scales, + fwd_scale_invs, + *aux_ctx_tensors, + ) = ctx.saved_tensors if not aux_ctx_tensors[0].is_contiguous(): aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: @@ -2909,136 +3825,271 @@ class FusedAttnFunc(torch.autograd.Function): dk = torch.empty_like(k) dv = torch.empty_like(v) maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - d_out, q, k, v, out = [maybe_contiguous(x) - for x in (d_out, q, k, v, out)] + d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)] flash_attn_cuda_bwd( - d_out, q, k, v, out, softmax_lse, dq, dk, dv, - cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv, - ctx.dropout_p, ctx.attn_scale, False, - "causal" in ctx.attn_mask_type, None, rng_state + d_out, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_kv, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ctx.dropout_p, + ctx.attn_scale, + False, + "causal" in ctx.attn_mask_type, + None, + rng_state, ) - dq = dq[..., :d_out.shape[-1]] - dk = dk[..., :d_out.shape[-1]] - dv = dv[..., :d_out.shape[-1]] + dq = dq[..., : d_out.shape[-1]] + dk = dk[..., : d_out.shape[-1]] + dv = dv[..., : d_out.shape[-1]] else: with torch.cuda.nvtx.range("_FusedAttn"): if ctx.fp8: logger.debug("Running backward in FP8") fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False) + ctx.fp8_meta["recipe"], fprop_tensor=False + ) if ctx.fp8_meta["recipe"].fp8_mha: d_out_fp8 = d_out - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: d_out_fp8 = cast_to_fp8( d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ).view(d_out.shape) + ctx.fp8_meta["scaling_bwd"], + META_DO, + fp8_dtype_backward, + ).view(d_out.shape) dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( - ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8, - fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q_fp8, + k_fp8, + v_fp8, + out_fp8, + d_out_fp8, + fp8_dtype_forward, + fp8_dtype_backward, + aux_ctx_tensors, ctx.fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # 
d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp - ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + fwd_scale_invs[META_QKV], # d_scale_qkv, + fwd_scale_invs[META_S], # d_scale_s, + fwd_scale_invs[META_O], # d_scale_o, + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp + fwd_scales[META_S], # q_scale_s + ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp + ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv + ctx.attn_scale, + ctx.dropout_p, + ctx.fast_zero_fill, + ctx.qkv_layout, + ctx.attn_bias_type, + ctx.attn_mask_type, + ) if ctx.fp8_meta["recipe"].fp8_mha: - dq = Float8Tensor(data=dq_fp8, + dq = Float8Tensor( + data=dq_fp8, fp8_meta=ctx.fp8_meta, fp8_meta_forward=False, fp8_meta_index=META_DQKV, fp8_dtype=fp8_dtype_backward, dtype=d_out_f8tensor.dtype, - ) - dk = Float8Tensor(data=dk_fp8, + ) + dk = Float8Tensor( + data=dk_fp8, fp8_meta=ctx.fp8_meta, fp8_meta_forward=False, fp8_meta_index=META_DQKV, fp8_dtype=fp8_dtype_backward, dtype=d_out_f8tensor.dtype, - ) - dv = Float8Tensor(data=dv_fp8, + ) + dv = Float8Tensor( + data=dv_fp8, fp8_meta=ctx.fp8_meta, fp8_meta_forward=False, fp8_meta_index=META_DQKV, fp8_dtype=fp8_dtype_backward, dtype=d_out_f8tensor.dtype, - ) + ) else: - qkv_group = len(ctx.qkv_layout.split('_')) + qkv_group = len(ctx.qkv_layout.split("_")) if qkv_group == 1: - dim = ctx.qkv_layout.find('3') - dqkv_fp8 = _combine_tensors([dq_fp8,dk_fp8,dv_fp8], dim) - dqkv_c_fp8 = dqkv_fp8.view(-1, - dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]) - dqkv = cast_from_fp8(dqkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape) - dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1,1,1]) + dim = ctx.qkv_layout.find("3") + dqkv_fp8 = _combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim) + dqkv_c_fp8 = dqkv_fp8.view( + -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1] + ) + dqkv = cast_from_fp8( + dqkv_c_fp8, + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + ctx.qkv_dtype, + ).view(dqkv_fp8.shape) + dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1]) dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]] if qkv_group == 2: dq = cast_from_fp8( dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape) - dim = ctx.qkv_layout.split('_')[1].find('2') - dkv_fp8 = _combine_tensors([dk_fp8,dv_fp8], dim) - dkv_c_fp8 = dkv_fp8.view(-1, - dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]) - dkv = cast_from_fp8(dkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape) - dk, dv = _SplitAlongDim.apply(dkv, dim, [1,1]) + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + ctx.qkv_dtype, + ).view(dq_fp8.shape) + dim = ctx.qkv_layout.split("_")[1].find("2") + dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim) + dkv_c_fp8 = dkv_fp8.view( + -1, 
dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] + ) + dkv = cast_from_fp8( + dkv_c_fp8, + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + ctx.qkv_dtype, + ).view(dkv_fp8.shape) + dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1]) dk, dv = [x.squeeze(dim) for x in [dk, dv]] if qkv_group == 3: dq = cast_from_fp8( dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape) + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + ctx.qkv_dtype, + ).view(dq_fp8.shape) dk = cast_from_fp8( dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dk_fp8.shape) + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + ctx.qkv_dtype, + ).view(dk_fp8.shape) dv = cast_from_fp8( dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dv_fp8.shape) + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + ctx.qkv_dtype, + ).view(dv_fp8.shape) else: - logger.debug("Running backward in %s",q.dtype) + logger.debug("Running backward in %s", q.dtype) if d_out.dtype == torch.uint8: d_out = d_out_f8tensor.from_float8(q.dtype) dq, dk, dv, *rest = fused_attn_bwd( - ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, k, v, out, d_out, - ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + out, + d_out, + ctx.qkv_dtype, + ctx.qkv_dtype, + aux_ctx_tensors, ctx.fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - None, None, None, None, None, None, None, None, None, None, - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ctx.attn_scale, + ctx.dropout_p, + ctx.fast_zero_fill, + ctx.qkv_layout, + ctx.attn_bias_type, + ctx.attn_mask_type, + ) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, None, None, None, - None, None, None, dq, dk, dv, None, None, None, - None, None, None, None, None, None, - None, None, None, None, None, None) + return ( + None, + None, + None, + None, + None, + None, + None, + None, + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) # else, return (dqkv, dbias) - return (None, None, None, None, None, None, - None, None, None, dq, dk, dv, None, rest[0], None, - None, None, None, None, None, None, - None, None, None, None, None, None) + return ( + None, + None, + None, + None, + None, + None, + None, + None, + None, + dq, + dk, + dv, + None, + rest[0], + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) class FusedAttention(TransformerEngineBaseModule): @@ -3085,8 +4136,9 @@ class FusedAttention(TransformerEngineBaseModule): self.attention_dropout = attention_dropout self.attention_dropout_ctx = attention_dropout_ctx self.attention_type = attention_type - self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1" - and get_device_compute_capability() == (9, 0)) + self.use_FAv2_bwd = os.getenv( + 
"NVTE_FUSED_ATTN_USE_FAv2_BWD", "0" + ) == "1" and get_device_compute_capability() == (9, 0) self.layer_number = 1 if layer_number is None else layer_number if deterministic: # workspace optimization path is deterministic @@ -3104,15 +4156,16 @@ class FusedAttention(TransformerEngineBaseModule): if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" - def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument + def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ Temporarily remove fused_attention._extra_state as a missing key when loading older TransformerEngine checkpoints. Will phase out this hook in TransformerEngine 2.0. """ for key in incompatible_keys.missing_keys: - if 'fused_attention._extra_state' in key: + if "fused_attention._extra_state" in key: incompatible_keys.missing_keys.remove(key) + self.register_load_state_dict_post_hook(remove_extra_states_check) def get_fp8_weights_scratchpad( @@ -3138,8 +4191,7 @@ class FusedAttention(TransformerEngineBaseModule): max_seqlen_kv: Optional[int] = None, attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, - fused_attention_backend: - tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, + fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, fast_zero_fill: bool = True, @@ -3149,33 +4201,39 @@ class FusedAttention(TransformerEngineBaseModule): is_first_microbatch: Optional[bool] = None, ) -> torch.Tensor: """fused attention fprop""" - assert (fused_attention_backend - != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend - ), 'No fused attention backend supports this input combination!' + assert ( + fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend + ), "No fused attention backend supports this input combination!" assert ( (query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8]) and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8]) and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8]) - ), 'FusedAttention only supports FP16 and BF16 data types.' + ), "FusedAttention only supports FP16 and BF16 data types." assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda - ), 'FusedAttention only supports CUDA tensors.' + ), "FusedAttention only supports CUDA tensors." assert ( qkv_layout in QKVLayouts - ), f"FusedAttention does not support qkv_layout = {qkv_layout}!" + ), f"FusedAttention does not support qkv_layout = {qkv_layout}!" 
context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1) - qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format in ['sbhd', 'bshd']: - if qkv_format == 'sbhd': + if qkv_format in ["sbhd", "bshd"]: + if qkv_format == "sbhd": batch_size, max_seqlen_q, max_seqlen_kv = ( - query_layer.shape[1], query_layer.shape[0], key_layer.shape[0]) - if qkv_format == 'bshd': + query_layer.shape[1], + query_layer.shape[0], + key_layer.shape[0], + ) + if qkv_format == "bshd": batch_size, max_seqlen_q, max_seqlen_kv = ( - query_layer.shape[0], query_layer.shape[1], key_layer.shape[1]) - if 'padding' in attn_mask_type: + query_layer.shape[0], + query_layer.shape[1], + key_layer.shape[1], + ) + if "padding" in attn_mask_type: assert not context_parallel, "Padding mask not supported with context parallelism!" if cu_seqlens_q is None or cu_seqlens_kv is None: @@ -3202,61 +4260,75 @@ class FusedAttention(TransformerEngineBaseModule): max_seqlen_kv, key_layer.device, ) - if qkv_format == 'thd': - assert (max_seqlen_q is not None + if qkv_format == "thd": + assert ( + max_seqlen_q is not None and max_seqlen_kv is not None and cu_seqlens_q is not None and cu_seqlens_kv is not None - ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" - if (seq_offsets_q is None + ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" + if ( + seq_offsets_q is None or seq_offsets_k is None or seq_offsets_v is None or seq_offsets_o is None - or context_parallel): - qkv_group = ''.join([x for x in qkv_layout if x not in 'bst']) - qkv_group = 'hd_hd_hd' if context_parallel else qkv_group + or context_parallel + ): + qkv_group = "".join([x for x in qkv_layout if x not in "bst"]) + qkv_group = "hd_hd_hd" if context_parallel else qkv_group num_heads = query_layer.shape[-2] num_gqa_groups = key_layer.shape[-2] head_dim = query_layer.shape[-1] seq_offsets_o = num_heads * head_dim * cu_seqlens_q - if qkv_group == 'hd_hd_hd': + if qkv_group == "hd_hd_hd": seq_offsets_q = num_heads * head_dim * cu_seqlens_q seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv - if qkv_group in ['3hd', 'h3d']: + if qkv_group in ["3hd", "h3d"]: seq_offsets_q = num_heads * head_dim * 3 * cu_seqlens_q seq_offsets_k = num_heads * head_dim * 3 * cu_seqlens_q seq_offsets_v = num_heads * head_dim * 3 * cu_seqlens_q - if qkv_group in ['hd_2hd', 'hd_h2d']: + if qkv_group in ["hd_2hd", "hd_h2d"]: seq_offsets_q = num_heads * head_dim * cu_seqlens_q seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv qkv_dtype = TE_DType[query_layer.dtype] - use_FAv2_bwd = (self.use_FAv2_bwd - and (core_attention_bias_type == "no_bias") - and (fused_attention_backend - == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)) + use_FAv2_bwd = ( + self.use_FAv2_bwd + and (core_attention_bias_type == "no_bias") + and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen) + ) if context_parallel: - assert (fused_attention_backend - == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen - ), f"{fused_attention_backend} does not work with context parallelism!" assert ( - core_attention_bias_type not in ["alibi"] - ), f"{core_attention_bias_type} is not supported with context parallelism!" 
- query_layer, key_layer, value_layer = [x.contiguous() - for x in (query_layer, key_layer, value_layer)] + fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + ), f"{fused_attention_backend} does not work with context parallelism!" + assert core_attention_bias_type not in [ + "alibi" + ], f"{core_attention_bias_type} is not supported with context parallelism!" + query_layer, key_layer, value_layer = [ + x.contiguous() for x in (query_layer, key_layer, value_layer) + ] with self.attention_dropout_ctx(): output = attn_forward_func_with_cp( self.training, - query_layer, key_layer, value_layer, - cu_seqlens_q, cu_seqlens_kv, - max_seqlen_q, max_seqlen_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + query_layer, + key_layer, + value_layer, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, self.attention_dropout if self.training else 0.0, - cp_group, cp_global_ranks, cp_stream, + cp_group, + cp_global_ranks, + cp_stream, softmax_scale=self.softmax_scale, qkv_format=qkv_format, attn_mask_type=attn_mask_type, @@ -3265,10 +4337,9 @@ class FusedAttention(TransformerEngineBaseModule): use_fused_attention=True, ) else: - with self.prepare_forward(query_layer, - is_first_microbatch, - num_gemms=3, - allow_non_contiguous=True) as query_layer: + with self.prepare_forward( + query_layer, is_first_microbatch, num_gemms=3, allow_non_contiguous=True + ) as query_layer: with self.attention_dropout_ctx(): forced_fp8_dpa = "" if self.fp8_meta["recipe"].fp8_mha: @@ -3282,13 +4353,21 @@ class FusedAttention(TransformerEngineBaseModule): self.fp8_meta["recipe"].fp8_mha, self.fp8_meta["recipe"].fp8_dpa, forced_fp8_dpa, - int(os.getenv("NVTE_FP8_DPA_BWD", "1"))) + int(os.getenv("NVTE_FP8_DPA_BWD", "1")), + ) output = FusedAttnFunc.apply( self.training, - max_seqlen_q, max_seqlen_kv, - cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - query_layer, key_layer, value_layer, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + query_layer, + key_layer, + value_layer, qkv_dtype, core_attention_bias, self.softmax_scale, @@ -3297,7 +4376,7 @@ class FusedAttention(TransformerEngineBaseModule): qkv_layout, core_attention_bias_type, attn_mask_type, - None, # rng_gen + None, # rng_gen fused_attention_backend, use_FAv2_bwd, self.fp8 and self.fp8_meta["recipe"].fp8_dpa, @@ -3424,7 +4503,7 @@ class DotProductAttention(torch.nn.Module): self.logger = logging.getLogger("DotProductAttention") self.qkv_format = qkv_format - attn_mask_type = attn_mask_type.replace(",","_") + attn_mask_type = attn_mask_type.replace(",", "_") if attn_mask_type == "causal_padding": attn_mask_type = "padding_causal" self.attn_mask_type = attn_mask_type @@ -3441,13 +4520,12 @@ class DotProductAttention(torch.nn.Module): self.hidden_size_per_attention_head = kv_channels - self.num_gqa_groups = ( - num_attention_heads if num_gqa_groups is None else num_gqa_groups - ) + self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) - assert (num_attention_heads % self.num_gqa_groups == 0 - ), "The number of attention heads must be divisible by the number of GQA groups!" 
+ assert ( + num_attention_heads % self.num_gqa_groups == 0 + ), "The number of attention heads must be divisible by the number of GQA groups!" self.rng_states_tracker = None if sequence_parallel or get_rng_state_tracker is None: @@ -3461,13 +4539,14 @@ class DotProductAttention(torch.nn.Module): softmax_scale = 1.0 / math.sqrt(kv_channels) self.device_compute_capability = get_device_compute_capability() - self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) \ - or torch.are_deterministic_algorithms_enabled() - - self.use_flash_attention = ( - int(os.getenv("NVTE_FLASH_ATTN", "1")) - and self.device_compute_capability >= (8, 0) + self.deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() ) + + self.use_flash_attention = int( + os.getenv("NVTE_FLASH_ATTN", "1") + ) and self.device_compute_capability >= (8, 0) if int(os.getenv("NVTE_FLASH_ATTN", "1")) == 0: self.logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") if self.device_compute_capability < (8, 0): @@ -3481,18 +4560,15 @@ class DotProductAttention(torch.nn.Module): " please install FlashAttention version >=2.4.1." ) - self.use_fused_attention = ( - int(os.getenv("NVTE_FUSED_ATTN", "1")) - and self.device_compute_capability >= (8, 0) - ) + self.use_fused_attention = int( + os.getenv("NVTE_FUSED_ATTN", "1") + ) and self.device_compute_capability >= (8, 0) if int(os.getenv("NVTE_FUSED_ATTN", "1")) == 0: self.logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") if self.device_compute_capability < (8, 0): self.logger.debug("Disabling FusedAttention for compute capability < sm80") - assert ( - attention_type in AttnTypes - ), f"attention_type {attention_type} not supported" + assert attention_type in AttnTypes, f"attention_type {attention_type} not supported" self.attention_type = attention_type self.attention_dropout = attention_dropout @@ -3503,23 +4579,28 @@ class DotProductAttention(torch.nn.Module): } if self.use_flash_attention: - self.flash_attention = FlashAttention(softmax_scale, - attention_type=attention_type, - layer_number=layer_number, - deterministic=self.deterministic, - **attn_kwargs) + self.flash_attention = FlashAttention( + softmax_scale, + attention_type=attention_type, + layer_number=layer_number, + deterministic=self.deterministic, + **attn_kwargs, + ) # Instantiating three types since use of flash-attn and FusedAttention # might be ruled out due to forward inputs. if self.use_fused_attention: - self.fused_attention = FusedAttention(softmax_scale, - attention_type=attention_type, - layer_number=layer_number, - deterministic=self.deterministic, - **attn_kwargs) + self.fused_attention = FusedAttention( + softmax_scale, + attention_type=attention_type, + layer_number=layer_number, + deterministic=self.deterministic, + **attn_kwargs, + ) self.unfused_attention = UnfusedDotProductAttention( - softmax_scale, **attn_kwargs, layer_number=layer_number) + softmax_scale, **attn_kwargs, layer_number=layer_number + ) def _checkpointed_attention_forward( self, @@ -3727,29 +4808,30 @@ class DotProductAttention(torch.nn.Module): assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda - ), 'DotProductAttention only supports CUDA tensors.' + ), "DotProductAttention only supports CUDA tensors." - assert (key_layer.shape == value_layer.shape - ), "Keys and values must have the same shape!" 
+ assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!" if attn_mask_type is not None: window_size = check_set_window_size(attn_mask_type, window_size) if attn_mask_type is None: attn_mask_type = self.attn_mask_type else: - attn_mask_type = attn_mask_type.replace(",","_") + attn_mask_type = attn_mask_type.replace(",", "_") if attn_mask_type == "causal_padding": attn_mask_type = "padding_causal" - assert (attn_mask_type in AttnMaskTypes - ), f"Attention mask type {attn_mask_type} is not supported!" - if qkv_format == 'thd': - assert ('padding' in attn_mask_type - ), "Attention mask type must be padding or padding_causal for qkv_format=thd!" + assert ( + attn_mask_type in AttnMaskTypes + ), f"Attention mask type {attn_mask_type} is not supported!" + if qkv_format == "thd": + assert ( + "padding" in attn_mask_type + ), "Attention mask type must be padding or padding_causal for qkv_format=thd!" if self.rng_states_tracker is not None and is_graph_capturing(): - assert ( - isinstance(self.rng_states_tracker, CudaRNGStatesTracker) + assert isinstance( + self.rng_states_tracker, CudaRNGStatesTracker ), "Unsupported RNG states tracker." assert ( graph_safe_rng_available() @@ -3768,7 +4850,9 @@ class DotProductAttention(torch.nn.Module): key_layer = key_layer.transpose(0, 1) value_layer = value_layer.transpose(0, 1) - (inference_key_memory, inference_value_memory, + ( + inference_key_memory, + inference_value_memory, ) = inference_params.key_value_memory_dict[self.layer_number] batch_start = inference_params.batch_size_offset @@ -3780,10 +4864,12 @@ class DotProductAttention(torch.nn.Module): assert sequence_end <= inference_key_memory.size(0) # Copy keys and values into KV-cache - inference_key_memory[ - sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer - inference_value_memory[ - sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer + inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( + key_layer + ) + inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( + value_layer + ) key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] @@ -3794,24 +4880,31 @@ class DotProductAttention(torch.nn.Module): key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() - assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition + assert ( + key_layer.shape[-2] == self.num_gqa_groups_per_partition and value_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" - assert (qkv_format in ['sbhd', 'bshd', 'thd'] - ), "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!" - - if qkv_format == 'thd': - assert (all(len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)) - ), "Queries, keys and values must be 3D tensors when qkv_format = thd!" - assert (cu_seqlens_q is not None and cu_seqlens_kv is not None - ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" - assert (cu_seqlens_q.shape == cu_seqlens_kv.shape + ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" + assert qkv_format in [ + "sbhd", + "bshd", + "thd", + ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!" 
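The mask-type handling reformatted near the top of this forward hunk accepts both comma-separated and underscore-separated spellings and maps the legacy "causal_padding" alias onto "padding_causal". A minimal sketch of that normalisation; the helper name is only for illustration and is not part of the module:

def normalize_mask_type(attn_mask_type: str) -> str:
    # "padding,causal" and "padding_causal" are treated identically
    attn_mask_type = attn_mask_type.replace(",", "_")
    # legacy spelling accepted for backward compatibility
    if attn_mask_type == "causal_padding":
        attn_mask_type = "padding_causal"
    return attn_mask_type

assert normalize_mask_type("padding,causal") == "padding_causal"
assert normalize_mask_type("causal_padding") == "padding_causal"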
+ + if qkv_format == "thd": + assert all( + len(x.shape) == 3 for x in (query_layer, key_layer, value_layer) + ), "Queries, keys and values must be 3D tensors when qkv_format = thd!" + assert ( + cu_seqlens_q is not None and cu_seqlens_kv is not None + ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" + assert ( + cu_seqlens_q.shape == cu_seqlens_kv.shape and len(cu_seqlens_q.shape) == 1 and len(cu_seqlens_kv.shape) == 1 - ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!" - assert (cu_seqlens_q.dtype == torch.int32 - and cu_seqlens_kv.dtype == torch.int32 - ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" + ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!" + assert ( + cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32 + ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" if max_seqlen_q is None: seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item()))) @@ -3819,32 +4912,39 @@ class DotProductAttention(torch.nn.Module): seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item()))) - if qkv_format in ['sbhd', 'bshd']: - assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)) - ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!" - if qkv_format == 'sbhd': + if qkv_format in ["sbhd", "bshd"]: + assert all( + len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) + ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!" + if qkv_format == "sbhd": max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0]) - if qkv_format == 'bshd': + if qkv_format == "bshd": max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1]) if cu_seqlens_q is not None: seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - assert (all(seqlens_q <= max_seqlen_q) - ), """Sequence lengths indicated by cu_seqlens_q must be no greater than + assert all( + seqlens_q <= max_seqlen_q + ), """Sequence lengths indicated by cu_seqlens_q must be no greater than the sequence dimention in 'query_layer'!""" if cu_seqlens_kv is not None: seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - assert (all(seqlens_kv <= max_seqlen_kv) - ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than + assert all( + seqlens_kv <= max_seqlen_kv + ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than the sequence dimention in 'key_layer' and 'value_layer'!""" - if (isinstance(query_layer, Float8Tensor) + if ( + isinstance(query_layer, Float8Tensor) and isinstance(key_layer, Float8Tensor) - and isinstance(value_layer, Float8Tensor)): + and isinstance(value_layer, Float8Tensor) + ): qkv_layout, query_layer._data, key_layer._data, value_layer._data = _get_qkv_layout( - query_layer._data, key_layer._data, value_layer._data, qkv_format = qkv_format) + query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format + ) else: qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout( - query_layer, key_layer, value_layer, qkv_format = qkv_format) + query_layer, key_layer, value_layer, qkv_format=qkv_format + ) # The priority for attention backends (subject to availability and clearing the filters) # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention. 
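For qkv_format = "thd", the branch above recovers per-sequence lengths from the cumulative offsets and, when max_seqlen_q/kv are not supplied, rounds the longest segment up to the next power of two. A small numeric sketch of that derivation, with invented tensor values:

import math

import torch

cu_seqlens_q = torch.tensor([0, 3, 10, 17], dtype=torch.int32)
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]  # tensor([3, 7, 7])
max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
assert max_seqlen_q == 8  # longest segment is 7, rounded up to the next power of two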
@@ -3856,7 +4956,7 @@ class DotProductAttention(torch.nn.Module): # certain asserts before executing the forward pass. # Filter: QKV layout. - if use_unfused_attention and qkv_format == 'thd': + if use_unfused_attention and qkv_format == "thd": self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd") use_unfused_attention = False @@ -3870,49 +4970,61 @@ class DotProductAttention(torch.nn.Module): use_fused_attention = False # Filter: Input type. - if (use_flash_attention - and (query_layer.dtype not in [torch.bfloat16, torch.float16] - or key_layer.dtype not in [torch.bfloat16, torch.float16] - or value_layer.dtype not in [torch.bfloat16, torch.float16] - or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer])) + if use_flash_attention and ( + query_layer.dtype not in [torch.bfloat16, torch.float16] + or key_layer.dtype not in [torch.bfloat16, torch.float16] + or value_layer.dtype not in [torch.bfloat16, torch.float16] + or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]) ): self.logger.debug( "Disabling FlashAttention due to unsupported QKV data types. " "Supported: [torch.bfloat16, torch.float16]. " "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.", - query_layer.dtype, key_layer.dtype, value_layer.dtype) + query_layer.dtype, + key_layer.dtype, + value_layer.dtype, + ) use_flash_attention = False - if (use_fused_attention - and (query_layer.dtype not in [torch.bfloat16, torch.float16] - or key_layer.dtype not in [torch.bfloat16, torch.float16] - or value_layer.dtype not in [torch.bfloat16, torch.float16]) + if use_fused_attention and ( + query_layer.dtype not in [torch.bfloat16, torch.float16] + or key_layer.dtype not in [torch.bfloat16, torch.float16] + or value_layer.dtype not in [torch.bfloat16, torch.float16] ): self.logger.debug( "Disabling FusedAttention due to unsupported QKV data types. " "Supported: [torch.bfloat16, torch.float16, Float8Tensor]. " "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.", - query_layer.dtype, key_layer.dtype, value_layer.dtype) + query_layer.dtype, + key_layer.dtype, + value_layer.dtype, + ) use_fused_attention = False # Filter: Device and dimensions. # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90 # FAv2 requires head_dim % 8 == 0 - if (use_flash_attention - and (query_layer.shape[-1] > 256 - or query_layer.shape[-1] % 8 != 0 - or (query_layer.shape[-1] > 192 - and self.device_compute_capability not in ((8, 0), (9, 0))))): + if use_flash_attention and ( + query_layer.shape[-1] > 256 + or query_layer.shape[-1] % 8 != 0 + or ( + query_layer.shape[-1] > 192 + and self.device_compute_capability not in ((8, 0), (9, 0)) + ) + ): self.logger.debug( "Disabling FlashAttention due to unsupported head_dim. " "Supported: %%8 == 0, and <= 256; sm80/90 for >192. " "Found: query_layer.shape[-1]=%s, key_layer.shape[-1]=%s, sm=%s", - query_layer.shape[-1], key_layer.shape[-1], - '.'.join([str(i) for i in self.device_compute_capability])) + query_layer.shape[-1], + key_layer.shape[-1], + ".".join([str(i) for i in self.device_compute_capability]), + ) use_flash_attention = False # Filter: cross attention + causal mask. 
# (in training mode) - if (use_flash_attention + if ( + use_flash_attention and inference_params is None and _flash_attn_2_1_plus and "causal" in attn_mask_type @@ -3925,8 +5037,9 @@ class DotProductAttention(torch.nn.Module): ) use_flash_attention = False - context_parallel = (self.cp_group is not None and \ - get_distributed_world_size(self.cp_group) != 1) + context_parallel = ( + self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1 + ) # Filter: sliding window attention. # UnfusedDotProductAttention can support SWA via arbitrary attention mask. @@ -3938,7 +5051,8 @@ class DotProductAttention(torch.nn.Module): if use_flash_attention: self.logger.debug( "Disabling FusedAttention as it requires flash-attn 2.3+ " - "and no context parallelism") + "and no context parallelism" + ) use_flash_attention = False # Filter: Attention mask type. @@ -3958,7 +5072,8 @@ class DotProductAttention(torch.nn.Module): self.logger.debug("Disabling FusedAttention for arbitrary mask") use_fused_attention = False - if (use_unfused_attention + if ( + use_unfused_attention and inference_params is None and "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv @@ -3969,24 +5084,28 @@ class DotProductAttention(torch.nn.Module): # Filter: bias. global _alibi_cache if alibi_slopes is not None: - assert (core_attention_bias_type == "alibi" - ), "core_attention_bias_type must be alibi in order to use alibi_slopes!" + assert ( + core_attention_bias_type == "alibi" + ), "core_attention_bias_type must be alibi in order to use alibi_slopes!" if self.layer_number == 1: _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True if core_attention_bias_type == "alibi": - assert (core_attention_bias is None - ), "core_attention_bias must be None when core_attention_bias_type is alibi!" - if (_alibi_cache["_num_heads"] != query_layer.shape[-2] + assert ( + core_attention_bias is None + ), "core_attention_bias must be None when core_attention_bias_type is alibi!" 
+ if ( + _alibi_cache["_num_heads"] != query_layer.shape[-2] or _alibi_cache["_max_seqlen_q"] != max_seqlen_q or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv - or _alibi_cache["_alibi_slopes"] is None): + or _alibi_cache["_alibi_slopes"] is None + ): _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True - if (use_flash_attention - and (core_attention_bias_type not in ["no_bias", "alibi"] - or core_attention_bias is not None)): + if use_flash_attention and ( + core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias is not None + ): self.logger.debug("Disabling FlashAttention for pre/post_scale_bias") use_flash_attention = False @@ -3995,12 +5114,20 @@ class DotProductAttention(torch.nn.Module): if core_attention_bias_type == "alibi" and use_fused_attention and alibi_slopes is not None: fu_core_attention_bias_type = "post_scale_bias" _, fu_core_attention_bias = get_alibi( - query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes, - bias_dtype=query_layer.dtype) - if (use_fused_attention + query_layer.shape[-2], + max_seqlen_q, + max_seqlen_kv, + alibi_slopes=alibi_slopes, + bias_dtype=query_layer.dtype, + ) + if ( + use_fused_attention and fu_core_attention_bias_type == "post_scale_bias" - and (fu_core_attention_bias.shape[0] != 1 - or fu_core_attention_bias.shape[1] != query_layer.shape[-2])): + and ( + fu_core_attention_bias.shape[0] != 1 + or fu_core_attention_bias.shape[1] != query_layer.shape[-2] + ) + ): if fu_core_attention_bias.requires_grad: # remove this line when cuDNN adds bwd support for # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s] @@ -4012,35 +5139,51 @@ class DotProductAttention(torch.nn.Module): if use_fused_attention: fused_attention_backend = tex.get_fused_attn_backend( - TE_DType[query_layer.dtype] - if not isinstance(query_layer, Float8Tensor) else query_layer._fp8_dtype, - TE_DType[key_layer.dtype] - if not isinstance(key_layer, Float8Tensor) else key_layer._fp8_dtype, + ( + TE_DType[query_layer.dtype] + if not isinstance(query_layer, Float8Tensor) + else query_layer._fp8_dtype + ), + ( + TE_DType[key_layer.dtype] + if not isinstance(key_layer, Float8Tensor) + else key_layer._fp8_dtype + ), QKVLayout[qkv_layout], AttnBiasType[fu_core_attention_bias_type], AttnMaskType[attn_mask_type], self.attention_dropout, - query_layer.shape[-2], # num_attn_heads - key_layer.shape[-2], # num_gqa_groups + query_layer.shape[-2], # num_attn_heads + key_layer.shape[-2], # num_gqa_groups max_seqlen_q, max_seqlen_kv, - query_layer.shape[-1], # head_dim + query_layer.shape[-1], # head_dim ) # DPA does not support FP8; for FP8, use cpp_extensions modules directly - is_backend_avail = (fused_attention_backend in - [FusedAttnBackend["F16_max512_seqlen"], + is_backend_avail = fused_attention_backend in [ + FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"], - FusedAttnBackend["FP8"]]) - use_fused_attention = ( \ - use_fused_attention and is_backend_avail and \ - (not context_parallel or \ - fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"])) - if (fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] + FusedAttnBackend["FP8"], + ] + use_fused_attention = ( + use_fused_attention + and is_backend_avail + and ( + not context_parallel + or fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + ) + ) + if ( + fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] and fu_core_attention_bias_type == "post_scale_bias" - and 
(fu_core_attention_bias.shape[0] != 1 - or fu_core_attention_bias.shape[1] != query_layer.shape[-2])): + and ( + fu_core_attention_bias.shape[0] != 1 + or fu_core_attention_bias.shape[1] != query_layer.shape[-2] + ) + ): self.logger.debug( - "Disabling FusedAttention as no backend supports the provided input") + "Disabling FusedAttention as no backend supports the provided input" + ) use_fused_attention = False # Filter: determinism. @@ -4055,71 +5198,84 @@ class DotProductAttention(torch.nn.Module): # Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path # on sm90 architectures. # - if (use_fused_attention + if ( + use_fused_attention and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] and self.deterministic - and self.device_compute_capability != (9, 0)): + and self.device_compute_capability != (9, 0) + ): self.logger.debug("Disabling FusedAttention for determinism reasons") use_fused_attention = False # Select FusedAttention on sm90 and FlashAttention on others for performance - if (use_flash_attention + if ( + use_flash_attention and use_fused_attention - and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]): + and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + ): if self.device_compute_capability == (9, 0): self.logger.debug( "Disabling FlashAttention to give FusedAttention preference on Hopper+ " - "for performance reasons") + "for performance reasons" + ) use_flash_attention = False run_config = { - "compute_capability":"sm"+str((lambda x,y: x*10+y)( - self.device_compute_capability[0],self.device_compute_capability[1])), - "q_dtype":query_layer.dtype, - "k_dtype":key_layer.dtype, - "v_dtype":value_layer.dtype, - "q_shape":list(query_layer.shape), - "k_shape":list(key_layer.shape), - "v_shape":list(value_layer.shape), - "qkv_format":qkv_format, - "qkv_layout":qkv_layout, - "mask_type":attn_mask_type, - "bias_type":core_attention_bias_type, - "bias_shape":core_attention_bias.shape if core_attention_bias is not None else None, - "dropout":self.attention_dropout, - "context_parallel":context_parallel, - "is_training":self.training, - "transformer_engine_version":te.__version__, - "flash_attn_version":_flash_attn_version, - "cudnn_version":'.'.join([str(i) for i in get_cudnn_version()])} + "compute_capability": "sm" + + str( + (lambda x, y: x * 10 + y)( + self.device_compute_capability[0], self.device_compute_capability[1] + ) + ), + "q_dtype": query_layer.dtype, + "k_dtype": key_layer.dtype, + "v_dtype": value_layer.dtype, + "q_shape": list(query_layer.shape), + "k_shape": list(key_layer.shape), + "v_shape": list(value_layer.shape), + "qkv_format": qkv_format, + "qkv_layout": qkv_layout, + "mask_type": attn_mask_type, + "bias_type": core_attention_bias_type, + "bias_shape": core_attention_bias.shape if core_attention_bias is not None else None, + "dropout": self.attention_dropout, + "context_parallel": context_parallel, + "is_training": self.training, + "transformer_engine_version": te.__version__, + "flash_attn_version": _flash_attn_version, + "cudnn_version": ".".join([str(i) for i in get_cudnn_version()]), + } if use_flash_attention: self.logger.info("Running with FlashAttention backend ") - self.logger.debug("Running with config=%s",run_config) + self.logger.debug("Running with config=%s", run_config) if core_attention_bias_type == "alibi": alibi_slopes, _ = get_alibi( - query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes) - return 
self.flash_attention(query_layer, - key_layer, - value_layer, - attention_mask=attention_mask, - qkv_layout=qkv_layout, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - attn_mask_type=attn_mask_type, - window_size=window_size, - alibi_slopes=alibi_slopes, - cp_group=self.cp_group, - cp_global_ranks=self.cp_global_ranks, - cp_stream=self.cp_stream, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv) + query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes + ) + return self.flash_attention( + query_layer, + key_layer, + value_layer, + attention_mask=attention_mask, + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + attn_mask_type=attn_mask_type, + window_size=window_size, + alibi_slopes=alibi_slopes, + cp_group=self.cp_group, + cp_global_ranks=self.cp_global_ranks, + cp_stream=self.cp_stream, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) if use_fused_attention: self.logger.info( - "Running with FusedAttention backend (sub-backend %s)", - int(fused_attention_backend)) - self.logger.debug("Running with config=%s",run_config) + "Running with FusedAttention backend (sub-backend %s)", int(fused_attention_backend) + ) + self.logger.debug("Running with config=%s", run_config) if checkpoint_core_attention: return self._checkpointed_attention_forward( self.fused_attention, @@ -4144,7 +5300,8 @@ class DotProductAttention(torch.nn.Module): cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, - is_first_microbatch=is_first_microbatch) + is_first_microbatch=is_first_microbatch, + ) return self.fused_attention( query_layer, key_layer, @@ -4167,46 +5324,52 @@ class DotProductAttention(torch.nn.Module): cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, - is_first_microbatch=is_first_microbatch) + is_first_microbatch=is_first_microbatch, + ) - assert (not context_parallel), \ - "Context parallelism is only implemented with Flash Attention and Fused Attention!" + assert ( + not context_parallel + ), "Context parallelism is only implemented with Flash Attention and Fused Attention!" from .cpu_offload import CPUOffloadEnabled + if CPUOffloadEnabled: warnings.warn( - "Attention activation Offloading is only implemented" - "with Flash Attention and Fused Attention!" - ) + "Attention activation Offloading is only implemented" + "with Flash Attention and Fused Attention!" 
+ ) if use_unfused_attention: self.logger.info("Running with UnfusedDotProductAttention backend") - self.logger.debug("Running with config=%s",run_config) + self.logger.debug("Running with config=%s", run_config) if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, query_layer, key_layer, value_layer, - qkv_layout = qkv_layout, - cu_seqlens_q = cu_seqlens_q, - cu_seqlens_kv = cu_seqlens_kv, - attn_mask_type = attn_mask_type, - attention_mask = attention_mask, - core_attention_bias_type = core_attention_bias_type, - core_attention_bias = core_attention_bias, - alibi_slopes = alibi_slopes) - return self.unfused_attention(query_layer, - key_layer, - value_layer, - qkv_layout = qkv_layout, - cu_seqlens_q = cu_seqlens_q, - cu_seqlens_kv = cu_seqlens_kv, - attn_mask_type = attn_mask_type, - attention_mask = attention_mask, - core_attention_bias_type = core_attention_bias_type, - core_attention_bias = core_attention_bias, - alibi_slopes = alibi_slopes) + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + attn_mask_type=attn_mask_type, + attention_mask=attention_mask, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + alibi_slopes=alibi_slopes, + ) + return self.unfused_attention( + query_layer, + key_layer, + value_layer, + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + attn_mask_type=attn_mask_type, + attention_mask=attention_mask, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + alibi_slopes=alibi_slopes, + ) raise Exception("No dot product attention support for the provided inputs!") @@ -4382,7 +5545,7 @@ class MultiheadAttention(torch.nn.Module): bias: bool = True, normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", - qkv_format: str = "sbhd" + qkv_format: str = "sbhd", ) -> None: super().__init__() @@ -4420,13 +5583,13 @@ class MultiheadAttention(torch.nn.Module): self.sequence_parallel = (tp_size > 1) and sequence_parallel self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size) - self.num_gqa_groups = ( - num_attention_heads if num_gqa_groups is None else num_gqa_groups - ) - assert (num_attention_heads % self.num_gqa_groups == 0 - ), "The number of attention heads must be divisible by the number of GQA groups!" - assert (self.num_gqa_groups % tp_size == 0 - ), "The number of GQA groups must be divisible by tensor parallel size!" + self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups + assert ( + num_attention_heads % self.num_gqa_groups == 0 + ), "The number of attention heads must be divisible by the number of GQA groups!" + assert ( + self.num_gqa_groups % tp_size == 0 + ), "The number of GQA groups must be divisible by tensor parallel size!" 
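The two divisibility asserts above guarantee that query heads and GQA groups split evenly across tensor-parallel ranks; the per-partition counts computed just below then follow from integer division. A worked example with arbitrary sizes (the module itself uses its divide() helper, plain // is used here):

num_attention_heads = 32
num_gqa_groups = 4
tp_size = 2

assert num_attention_heads % num_gqa_groups == 0
assert num_gqa_groups % tp_size == 0

num_attention_heads_per_partition = num_attention_heads // tp_size  # 16 query heads per rank
num_gqa_groups_per_partition = num_gqa_groups // tp_size  # 2 KV heads per rank
queries_per_kv_head = num_attention_heads_per_partition // num_gqa_groups_per_partition  # 8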
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) self.hidden_size_per_attention_head = kv_channels @@ -4448,11 +5611,13 @@ class MultiheadAttention(torch.nn.Module): if self.attention_type == "self": parameters_split = None if not fuse_qkv_params: - parameters_split = collections.OrderedDict([ - ("query", self.hidden_size_q), - ("key", self.hidden_size_kv), - ("value", self.hidden_size_kv), - ]) + parameters_split = collections.OrderedDict( + [ + ("query", self.hidden_size_q), + ("key", self.hidden_size_kv), + ("value", self.hidden_size_kv), + ] + ) if self.input_layernorm: self.layernorm_qkv = LayerNormLinear( hidden_size, @@ -4555,7 +5720,6 @@ class MultiheadAttention(torch.nn.Module): **common_gemm_kwargs, ) - def _allocate_memory( self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype ) -> torch.Tensor: @@ -4694,13 +5858,14 @@ class MultiheadAttention(torch.nn.Module): window_size = self.window_size if "padding" in attn_mask_type and attention_mask is not None: - for i,_ in enumerate(attention_mask): + for i, _ in enumerate(attention_mask): assert ( attention_mask[i].dtype == torch.bool ), "Attention mask must be in boolean type!" - assert (core_attention_bias_type in AttnBiasTypes - ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" + assert ( + core_attention_bias_type in AttnBiasTypes + ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" # ================================================= # Pre-allocate memory for key-values for inference @@ -4745,11 +5910,12 @@ class MultiheadAttention(torch.nn.Module): mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + is_first_module_in_mha=True, # specific to FP8 MHA ) - num_queries_per_key_value = (self.num_attention_heads_per_partition // - self.num_gqa_groups_per_partition) + num_queries_per_key_value = ( + self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition + ) if self.qkv_weight_interleaved: # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn] new_tensor_shape = mixed_x_layer.size()[:-1] + ( @@ -4764,7 +5930,7 @@ class MultiheadAttention(torch.nn.Module): new_tensor_shape = mixed_x_layer.size()[:-1] + ( (num_queries_per_key_value + 2), self.num_gqa_groups_per_partition, - self.hidden_size_per_attention_head + self.hidden_size_per_attention_head, ) # split along third last dimension split_dim = -3 @@ -4783,21 +5949,24 @@ class MultiheadAttention(torch.nn.Module): ) else: query_layer, key_layer, value_layer = torch.split( - mixed_x_layer, (num_queries_per_key_value, 1, 1), dim = split_dim, - ) + mixed_x_layer, + (num_queries_per_key_value, 1, 1), + dim=split_dim, + ) # query: -> [sq, b, np, hn] # key, value: -> [sq, b, ng, hn] - query_layer, key_layer, value_layer = (x.reshape(x.size(0), x.size(1), -1, - self.hidden_size_per_attention_head) - for x in (query_layer, key_layer, value_layer)) + query_layer, key_layer, value_layer = ( + x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head) + for x in (query_layer, key_layer, value_layer) + ) elif self.attention_type == "cross": # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)] mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + is_first_module_in_mha=True, # specific to FP8 MHA ) if self.qkv_weight_interleaved: @@ -4822,15 +5991,25 @@ class 
MultiheadAttention(torch.nn.Module): # mixed_kv_layer --> 2 [sk, b, ng, hn] if not is_in_onnx_export_mode(): key_layer, value_layer = _SplitAlongDim.apply( - mixed_kv_layer, split_dim, mixed_kv_layer.shape[split_dim] // 2, + mixed_kv_layer, + split_dim, + mixed_kv_layer.shape[split_dim] // 2, ) else: key_layer, value_layer = torch.split( - mixed_kv_layer, mixed_kv_layer.shape[split_dim] // 2, dim = split_dim, + mixed_kv_layer, + mixed_kv_layer.shape[split_dim] // 2, + dim=split_dim, ) - key_layer, value_layer = (x.reshape( - x.size(0), x.size(1), -1, self.hidden_size_per_attention_head, - ) for x in (key_layer, value_layer)) + key_layer, value_layer = ( + x.reshape( + x.size(0), + x.size(1), + -1, + self.hidden_size_per_attention_head, + ) + for x in (key_layer, value_layer) + ) # Attention head [sq, b, h] --> [sq, b, hp] if self.input_layernorm: @@ -4846,7 +6025,7 @@ class MultiheadAttention(torch.nn.Module): query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + is_first_module_in_mha=True, # specific to FP8 MHA ) # [sq, b, hp] --> [sq, b, np, hn] @@ -4861,12 +6040,12 @@ class MultiheadAttention(torch.nn.Module): # ====================================================== if rotary_pos_emb is not None: - assert (not isinstance(query_layer, Float8Tensor) - and not isinstance(key_layer, Float8Tensor) - ), "RoPE is not supported for Float8Tensors!" + assert not isinstance(query_layer, Float8Tensor) and not isinstance( + key_layer, Float8Tensor + ), "RoPE is not supported for Float8Tensors!" # duplicate the pos_emb for self attention if not isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = ((rotary_pos_emb,) * 2) + rotary_pos_emb = (rotary_pos_emb,) * 2 q_pos_emb, k_pos_emb = rotary_pos_emb diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index cc19159ca305094fcfd449736d2f340d631dd297..88d1aa1f5dbd83d6aac2ab6219ffd82a13784835 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -29,9 +29,22 @@ AttnTypes = ("self", "cross") AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias", "alibi") QKVLayouts = ( - "sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", - "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", - "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd") + "sb3hd", + "sbh3d", + "sbhd_sb2hd", + "sbhd_sbh2d", + "sbhd_sbhd_sbhd", + "bs3hd", + "bsh3d", + "bshd_bs2hd", + "bshd_bsh2d", + "bshd_bshd_bshd", + "t3hd", + "th3d", + "thd_t2hd", + "thd_th2d", + "thd_thd_thd", +) LayerTypes = ("encoder", "decoder") diff --git a/transformer_engine/pytorch/cpp_extensions/activation.py b/transformer_engine/pytorch/cpp_extensions/activation.py index 375b728b436904471b830c44ced664c0f23e8abf..767fe25291c1815547506ad922c5815228f1591f 100644 --- a/transformer_engine/pytorch/cpp_extensions/activation.py +++ b/transformer_engine/pytorch/cpp_extensions/activation.py @@ -8,7 +8,7 @@ import torch import transformer_engine_torch as tex -__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu', 'srelu'] +__all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"] def gelu( @@ -167,6 +167,7 @@ def qgelu( otype, ) + def srelu( inp: torch.Tensor, fp8_meta_tensor: tex.FP8TensorMeta, diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py index 294b32f10e9de24d668ded5d739bb3435a463032..2856d4727ba003404b45a77a8532529983223842 
100644 --- a/transformer_engine/pytorch/cpp_extensions/cast.py +++ b/transformer_engine/pytorch/cpp_extensions/cast.py @@ -8,8 +8,7 @@ import torch import transformer_engine_torch as tex -__all__ = ['cast_to_fp8', - 'cast_from_fp8'] +__all__ = ["cast_to_fp8", "cast_from_fp8"] def cast_to_fp8( @@ -30,7 +29,7 @@ def cast_to_fp8( fp8_meta_tensor.amax_history, fp8_meta_tensor.scale_inv, fp8_tensor, - otype + otype, ) return None @@ -43,6 +42,7 @@ def cast_to_fp8( otype, ) + def cast_from_fp8( inp: torch.Tensor, fp8_meta_tensor: tex.FP8TensorMeta, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index f0b5052c24a37813e52c0b38c7704ef84e1a4864..6a6860391deeba1ddf3eca728b289f535b550280 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -11,16 +11,18 @@ from transformer_engine_torch import ( NVTE_QKV_Layout, NVTE_Bias_Type, NVTE_Mask_Type, - NVTE_Fused_Attn_Backend + NVTE_Fused_Attn_Backend, ) -__all__ = ['fused_attn_fwd_qkvpacked', - 'fused_attn_bwd_qkvpacked', - 'fused_attn_fwd_kvpacked', - 'fused_attn_bwd_kvpacked', - 'fused_attn_fwd', - 'fused_attn_bwd'] +__all__ = [ + "fused_attn_fwd_qkvpacked", + "fused_attn_bwd_qkvpacked", + "fused_attn_fwd_kvpacked", + "fused_attn_bwd_kvpacked", + "fused_attn_fwd", + "fused_attn_bwd", +] TORCH_DType = { @@ -48,28 +50,28 @@ QKVLayout = { "thd_t2hd": NVTE_QKV_Layout.NVTE_THD_T2HD, "thd_th2d": NVTE_QKV_Layout.NVTE_THD_TH2D, "thd_thd_thd": NVTE_QKV_Layout.NVTE_THD_THD_THD, - } +} AttnBiasType = { "no_bias": NVTE_Bias_Type.NVTE_NO_BIAS, "pre_scale_bias": NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS, "post_scale_bias": NVTE_Bias_Type.NVTE_POST_SCALE_BIAS, "alibi": NVTE_Bias_Type.NVTE_ALIBI, - } +} AttnMaskType = { "no_mask": NVTE_Mask_Type.NVTE_NO_MASK, "padding": NVTE_Mask_Type.NVTE_PADDING_MASK, "causal": NVTE_Mask_Type.NVTE_CAUSAL_MASK, "padding_causal": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK, - } +} FusedAttnBackend = { "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, "F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, "FP8": NVTE_Fused_Attn_Backend.NVTE_FP8, "No_Backend": NVTE_Fused_Attn_Backend.NVTE_No_Backend, - } +} BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 @@ -197,18 +199,20 @@ def fused_attn_fwd_qkvpacked( attn_scale = 1.0 / math.sqrt(d) if attn_bias_type not in ["no_bias", "alibi"]: - assert (attn_bias is not None - ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." - assert (attn_bias.dtype == qkv.dtype - ), "attn_bias tensor must be in the same dtype as qkv." + assert ( + attn_bias is not None + ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." + assert attn_bias.dtype == qkv.dtype, "attn_bias tensor must be in the same dtype as qkv." - assert (fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert ( + fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." 
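The lookup tables reformatted in this hunk translate the string names used throughout the Python layer into the transformer_engine_torch enums consumed by the fused-attention kernels. An illustrative lookup, assuming the module is importable from an installed package at the path shown in the diff header:

from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    AttnMaskType,
    FusedAttnBackend,
    QKVLayout,
)

qkv_layout = QKVLayout["thd_thd_thd"]  # NVTE_QKV_Layout.NVTE_THD_THD_THD
mask_type = AttnMaskType["padding_causal"]  # NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
backend = FusedAttnBackend["F16_arbitrary_seqlen"]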
# BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = (max_seqlen * max_seqlen - + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA + rng_elts_per_thread = ( + max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 + ) // BACKEND_F16m512_FP8_THREADS_PER_CTA # BF16/FP16 fused attention API from fmha_v2 if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: @@ -216,30 +220,45 @@ def fused_attn_fwd_qkvpacked( # FP8 fused attention API from fmha_v2 if fused_attention_backend == FusedAttnBackend["FP8"]: - rng_elts_per_thread = (max_seqlen * max_seqlen - + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA - - assert (d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert (d_scale_s is not None - ), "q_scale_s is required as an input for FP8 fused attention." - assert (q_scale_s is not None - ), "q_scale_s is required as an input for FP8 fused attention." - assert (q_scale_o is not None - ), "q_scale_o is required as an input for FP8 fused attention." - assert (amax_s is not None - ), "amax_s is required as an input for FP8 fused attention." - assert (amax_o is not None - ), "amax_o is required as an input for FP8 fused attention." + rng_elts_per_thread = ( + max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 + ) // BACKEND_F16m512_FP8_THREADS_PER_CTA + + assert ( + d_scale_qkv is not None + ), "d_scale_qkv is required as an input for FP8 fused attention." + assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." + assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." + assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." + assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." + assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." # execute kernel output_tensors = tex.fused_attn_fwd_qkvpacked( - max_seqlen, is_training, attn_scale, dropout, fast_zero_fill, - QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens, qkv, qkv_dtype, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, - rng_gen, rng_elts_per_thread, + max_seqlen, + is_training, + attn_scale, + dropout, + fast_zero_fill, + QKVLayout[qkv_layout], + AttnBiasType[attn_bias_type], + AttnMaskType[attn_mask_type], + cu_seqlens, + qkv, + qkv_dtype, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + d_scale_qkv, + d_scale_s, + q_scale_s, + q_scale_o, + amax_s, + amax_o, + attn_bias, + rng_gen, + rng_elts_per_thread, ) # out, aux_ctx_tensors @@ -360,35 +379,60 @@ def fused_attn_bwd_qkvpacked( d = qkv.size(-1) attn_scale = 1.0 / math.sqrt(d) - assert (fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert ( + fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - assert (len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." + assert ( + len(aux_ctx_tensors) >= 1 + ), "aux_ctx_tensors must contain rng_state as its last element." 
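The rng_elts_per_thread expression in the forward hunk above is a ceiling division: each CTA thread is assigned ceil(max_seqlen^2 / threads_per_CTA) RNG elements. A quick numeric check using the constant defined earlier in this file:

BACKEND_F16m512_FP8_THREADS_PER_CTA = 128  # value from this module
max_seqlen = 512

rng_elts_per_thread = (
    max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
) // BACKEND_F16m512_FP8_THREADS_PER_CTA
assert rng_elts_per_thread == 2048  # ceil(512 * 512 / 128)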
if fused_attention_backend == FusedAttnBackend["FP8"]: - assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention." - assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention." - assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention." - assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention." - assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention." - assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention." - assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention." - assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention." - assert (amax_dp is not None), "amax_dp is required for FP8 fused attention." - assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention." - assert (len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." + assert d_scale_qkv is not None, "d_scale_qkv is required for FP8 fused attention." + assert d_scale_s is not None, "d_scale_s is required for FP8 fused attention." + assert d_scale_o is not None, "d_scale_o is required for FP8 fused attention." + assert d_scale_do is not None, "d_scale_do is required for FP8 fused attention." + assert d_scale_dp is not None, "d_scale_dp is required for FP8 fused attention." + assert q_scale_s is not None, "q_scale_s is required for FP8 fused attention." + assert q_scale_dp is not None, "q_scale_dp is required for FP8 fused attention." + assert q_scale_dqkv is not None, "q_scale_dqkv is required for FP8 fused attention." + assert amax_dp is not None, "amax_dp is required for FP8 fused attention." + assert amax_dqkv is not None, "amax_dqkv is required for FP8 fused attention." + assert ( + len(aux_ctx_tensors) == 3 + ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." # execute kernel output_tensors = tex.fused_attn_bwd_qkvpacked( - max_seqlen, attn_scale, dropout, fast_zero_fill, - QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens, qkv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, - q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, + max_seqlen, + attn_scale, + dropout, + fast_zero_fill, + QKVLayout[qkv_layout], + AttnBiasType[attn_bias_type], + AttnMaskType[attn_mask_type], + cu_seqlens, + qkv, + o, + d_o, + qkv_dtype, + dqkv_dtype, + aux_ctx_tensors, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + d_scale_qkv, + d_scale_s, + d_scale_o, + d_scale_do, + d_scale_dp, + q_scale_s, + q_scale_dp, + q_scale_dqkv, + amax_dp, + amax_dqkv, ) return output_tensors @@ -527,18 +571,20 @@ def fused_attn_fwd_kvpacked( attn_scale = 1.0 / math.sqrt(d) if attn_bias_type not in ["no_bias", "alibi"]: - assert (attn_bias is not None - ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." - assert (attn_bias.dtype == q.dtype - ), "attn_bias tensor must be in the same dtype as q and kv." + assert ( + attn_bias is not None + ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." + assert attn_bias.dtype == q.dtype, "attn_bias tensor must be in the same dtype as q and kv." 
- assert (fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert ( + fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv - + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA + rng_elts_per_thread = ( + max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 + ) // BACKEND_F16m512_FP8_THREADS_PER_CTA # BF16/FP16 fused attention API from fmha_v2 if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: @@ -546,30 +592,48 @@ def fused_attn_fwd_kvpacked( # FP8 fused attention API from fmha_v2 if fused_attention_backend == FusedAttnBackend["FP8"]: - rng_elts_per_thread = (max_seqlen_q * max_seqlen_q - + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA - - assert (d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert (d_scale_s is not None - ), "q_scale_s is required as an input for FP8 fused attention." - assert (q_scale_s is not None - ), "q_scale_s is required as an input for FP8 fused attention." - assert (q_scale_o is not None - ), "q_scale_o is required as an input for FP8 fused attention." - assert (amax_s is not None - ), "amax_s is required as an input for FP8 fused attention." - assert (amax_o is not None - ), "amax_o is required as an input for FP8 fused attention." + rng_elts_per_thread = ( + max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 + ) // BACKEND_F16m512_FP8_THREADS_PER_CTA + + assert ( + d_scale_qkv is not None + ), "d_scale_qkv is required as an input for FP8 fused attention." + assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." + assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." + assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." + assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." + assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." # execute kernel output_tensors = tex.fused_attn_fwd_kvpacked( - max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, - QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, - attn_bias, rng_gen, rng_elts_per_thread, + max_seqlen_q, + max_seqlen_kv, + is_training, + attn_scale, + dropout, + fast_zero_fill, + QKVLayout[qkv_layout], + AttnBiasType[attn_bias_type], + AttnMaskType[attn_mask_type], + cu_seqlens_q, + cu_seqlens_kv, + q, + kv, + qkv_dtype, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + d_scale_qkv, + d_scale_s, + q_scale_s, + q_scale_o, + amax_s, + amax_o, + attn_bias, + rng_gen, + rng_elts_per_thread, ) # out, aux_ctx_tensors @@ -704,35 +768,63 @@ def fused_attn_bwd_kvpacked( d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) - assert (fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." 
+ assert ( + fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - assert (len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." + assert ( + len(aux_ctx_tensors) >= 1 + ), "aux_ctx_tensors must contain rng_state as its last element." if fused_attention_backend == FusedAttnBackend["FP8"]: - assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention." - assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention." - assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention." - assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention." - assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention." - assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention." - assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention." - assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention." - assert (amax_dp is not None), "amax_dp is required for FP8 fused attention." - assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention." - assert (len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." + assert d_scale_qkv is not None, "d_scale_qkv is required for FP8 fused attention." + assert d_scale_s is not None, "d_scale_s is required for FP8 fused attention." + assert d_scale_o is not None, "d_scale_o is required for FP8 fused attention." + assert d_scale_do is not None, "d_scale_do is required for FP8 fused attention." + assert d_scale_dp is not None, "d_scale_dp is required for FP8 fused attention." + assert q_scale_s is not None, "q_scale_s is required for FP8 fused attention." + assert q_scale_dp is not None, "q_scale_dp is required for FP8 fused attention." + assert q_scale_dqkv is not None, "q_scale_dqkv is required for FP8 fused attention." + assert amax_dp is not None, "amax_dp is required for FP8 fused attention." + assert amax_dqkv is not None, "amax_dqkv is required for FP8 fused attention." + assert ( + len(aux_ctx_tensors) == 3 + ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." 
# execute kernel output_tensors = tex.fused_attn_bwd_kvpacked( - max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, - QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, - q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, + max_seqlen_q, + max_seqlen_kv, + attn_scale, + dropout, + fast_zero_fill, + QKVLayout[qkv_layout], + AttnBiasType[attn_bias_type], + AttnMaskType[attn_mask_type], + cu_seqlens_q, + cu_seqlens_kv, + q, + kv, + o, + d_o, + qkv_dtype, + dqkv_dtype, + aux_ctx_tensors, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + d_scale_qkv, + d_scale_s, + d_scale_o, + d_scale_do, + d_scale_dp, + q_scale_s, + q_scale_dp, + q_scale_dqkv, + amax_dp, + amax_dqkv, ) return output_tensors @@ -878,18 +970,20 @@ def fused_attn_fwd( attn_scale = 1.0 / math.sqrt(d) if attn_bias_type not in ["no_bias", "alibi"]: - assert (attn_bias is not None - ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." - assert (attn_bias.dtype == q.dtype - ), "attn_bias tensor must be in the same dtype as q and kv." + assert ( + attn_bias is not None + ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." + assert attn_bias.dtype == q.dtype, "attn_bias tensor must be in the same dtype as q and kv." - assert (fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert ( + fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv - + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA + rng_elts_per_thread = ( + max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 + ) // BACKEND_F16m512_FP8_THREADS_PER_CTA # BF16/FP16 fused attention API from fmha_v2 if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: @@ -897,30 +991,49 @@ def fused_attn_fwd( # FP8 fused attention API from fmha_v2 if fused_attention_backend == FusedAttnBackend["FP8"]: - rng_elts_per_thread = (max_seqlen_q * max_seqlen_q - + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA - - assert (d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert (d_scale_s is not None - ), "q_scale_s is required as an input for FP8 fused attention." - assert (q_scale_s is not None - ), "q_scale_s is required as an input for FP8 fused attention." - assert (q_scale_o is not None - ), "q_scale_o is required as an input for FP8 fused attention." - assert (amax_s is not None - ), "amax_s is required as an input for FP8 fused attention." - assert (amax_o is not None - ), "amax_o is required as an input for FP8 fused attention." + rng_elts_per_thread = ( + max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 + ) // BACKEND_F16m512_FP8_THREADS_PER_CTA + + assert ( + d_scale_qkv is not None + ), "d_scale_qkv is required as an input for FP8 fused attention." + assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." 
+ assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." + assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." + assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." + assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." # execute kernel output_tensors = tex.fused_attn_fwd( - max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, - QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, - attn_bias, rng_gen, rng_elts_per_thread, + max_seqlen_q, + max_seqlen_kv, + is_training, + attn_scale, + dropout, + fast_zero_fill, + QKVLayout[qkv_layout], + AttnBiasType[attn_bias_type], + AttnMaskType[attn_mask_type], + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + qkv_dtype, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + d_scale_qkv, + d_scale_s, + q_scale_s, + q_scale_o, + amax_s, + amax_o, + attn_bias, + rng_gen, + rng_elts_per_thread, ) # out, aux_ctx_tensors @@ -1063,35 +1176,64 @@ def fused_attn_bwd( d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) - assert (fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert ( + fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - assert (len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." + assert ( + len(aux_ctx_tensors) >= 1 + ), "aux_ctx_tensors must contain rng_state as its last element." if fused_attention_backend == FusedAttnBackend["FP8"]: - assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention." - assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention." - assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention." - assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention." - assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention." - assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention." - assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention." - assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention." - assert (amax_dp is not None), "amax_dp is required for FP8 fused attention." - assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention." - assert (len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." + assert d_scale_qkv is not None, "d_scale_qkv is required for FP8 fused attention." + assert d_scale_s is not None, "d_scale_s is required for FP8 fused attention." + assert d_scale_o is not None, "d_scale_o is required for FP8 fused attention." + assert d_scale_do is not None, "d_scale_do is required for FP8 fused attention." + assert d_scale_dp is not None, "d_scale_dp is required for FP8 fused attention." + assert q_scale_s is not None, "q_scale_s is required for FP8 fused attention." + assert q_scale_dp is not None, "q_scale_dp is required for FP8 fused attention." 
+ assert q_scale_dqkv is not None, "q_scale_dqkv is required for FP8 fused attention." + assert amax_dp is not None, "amax_dp is required for FP8 fused attention." + assert amax_dqkv is not None, "amax_dqkv is required for FP8 fused attention." + assert ( + len(aux_ctx_tensors) == 3 + ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." # execute kernel output_tensors = tex.fused_attn_bwd( - max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, - QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, - seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, - d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, - q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, + max_seqlen_q, + max_seqlen_kv, + attn_scale, + dropout, + fast_zero_fill, + QKVLayout[qkv_layout], + AttnBiasType[attn_bias_type], + AttnMaskType[attn_mask_type], + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + o, + d_o, + qkv_dtype, + dqkv_dtype, + aux_ctx_tensors, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + seq_offsets_o, + d_scale_qkv, + d_scale_s, + d_scale_o, + d_scale_do, + d_scale_dp, + q_scale_s, + q_scale_dp, + q_scale_dqkv, + amax_dp, + amax_dqkv, ) return output_tensors diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index aa873459d42808a4f270e49cd2128eb4fb74e3e3..bb99e1d5ede1ec4fca621bc14a9e74a25bbffcd0 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -10,7 +10,7 @@ from ..constants import TE_DType from ..utils import assert_dim_for_fp8_exec -__all__ = ['gemm', 'fp8_gemm'] +__all__ = ["gemm", "fp8_gemm"] def fp8_gemm( @@ -27,7 +27,7 @@ def fp8_gemm( gelu: bool = False, accumulate: bool = False, out: Optional[torch.Tensor] = None, - out_index = None, + out_index=None, fp8_meta_tensor: tex.FP8TensorMeta = None, bias: Optional[torch.Tensor] = None, use_bias: bool = False, @@ -89,22 +89,35 @@ def fp8_gemm( workspace, workspace.shape[0], accumulate, - use_split_accumulator) + use_split_accumulator, + ) fn = torch.ops.tex_ts.te_gemm_ts if ub_algo is not None: - assert ub is not None, 'ub object is None!' + assert ub is not None, "ub object is None!" 
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: fn = ub.bulk_overlap extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) - args = tuple(args + (1, extra_output_tensor,)) + args = tuple( + args + + ( + 1, + extra_output_tensor, + ) + ) elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: fn = ub.bulk_overlap extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) - args = tuple(args + (0, extra_output_tensor,)) + args = tuple( + args + + ( + 0, + extra_output_tensor, + ) + ) elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P: fn = ub.split_overlap_ag_p2p extra_output_tensor = ( @@ -121,25 +134,35 @@ def fp8_gemm( fn = ub.split_overlap_rs assert ( extra_output_tensor is not None - ), 'SPLIT_PIPELINED_RS requires extra output tensor' - args = tuple(args + (True, extra_output_tensor,)) + ), "SPLIT_PIPELINED_RS requires extra output tensor" + args = tuple( + args + + ( + True, + extra_output_tensor, + ) + ) elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: fn = ub.split_overlap_rs_p2p assert ( extra_output_tensor is not None - ), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor' + ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" args = tuple(args + (extra_output_tensor,)) elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS: fn = ub.atomic_gemm_overlap_rs - assert ( - extra_output_tensor is not None - ), 'ATOMIC_GEMM_RS requires extra output tensor' - args = tuple(args + (True, extra_output_tensor,)) + assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor" + args = tuple( + args + + ( + True, + extra_output_tensor, + ) + ) elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P: fn = ub.atomic_gemm_overlap_rs_p2p assert ( extra_output_tensor is not None - ), 'ATOMIC_GEMM_RS_P2P requires extra output tensor' + ), "ATOMIC_GEMM_RS_P2P requires extra output tensor" args = tuple(args + (extra_output_tensor,)) if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: out = fn(*args) @@ -172,7 +195,7 @@ def gemm( transa = layout[0] == "T" transb = layout[1] == "T" empty_tensor = torch.Tensor() - fp8_index = -1 # dummy index + fp8_index = -1 # dummy index if out is None: out = torch.empty( @@ -196,8 +219,9 @@ def gemm( if A.nelement() == 0 or B.nelement() == 0: return out, grad_bias, gelu_input - assert A.dtype == dtype and B.dtype == dtype, \ - f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}' + assert ( + A.dtype == dtype and B.dtype == dtype + ), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}" input_dtype = TE_DType[dtype] output_dtype = TE_DType[out.dtype] if use_bias: @@ -217,9 +241,9 @@ def gemm( input_dtype, transb, out, - empty_tensor, # out_scale + empty_tensor, # out_scale output_dtype, - empty_tensor, # out_amax + empty_tensor, # out_amax grad_bias if grad else bias, bias_dtype, gelu_input, @@ -231,7 +255,7 @@ def gemm( ) fn = torch.ops.tex_ts.te_gemm_ts if ub_algo is not None: - assert ub is not None, 'ub object is None!' + assert ub is not None, "ub object is None!" 
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: fn = ub.bulk_overlap args = tuple(args + (1, empty_tensor)) @@ -248,13 +272,19 @@ def gemm( fn = ub.split_overlap_rs assert ( extra_output_tensor is not None - ), 'SPLIT_PIPELINED_RS requires extra output tensor' - args = tuple(args + (False, extra_output_tensor,)) + ), "SPLIT_PIPELINED_RS requires extra output tensor" + args = tuple( + args + + ( + False, + extra_output_tensor, + ) + ) elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: fn = ub.split_overlap_rs_p2p assert ( extra_output_tensor is not None - ), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor' + ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" args = tuple(args + (extra_output_tensor,)) _ = fn(*args) diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py index 74386db9120ec6cdb71e4a6e0878670e4f72e323..dd90bb0b66c0183c9b0bc9042884d7666a6ba718 100644 --- a/transformer_engine/pytorch/cpp_extensions/normalization.py +++ b/transformer_engine/pytorch/cpp_extensions/normalization.py @@ -8,12 +8,14 @@ import torch import transformer_engine_torch as tex -__all__ = ['layernorm_fwd_fp8', - 'layernorm_fwd_fp8_inf', - 'layernorm_fwd_inf', - 'rmsnorm_fwd_fp8', - 'rmsnorm_fwd_fp8_inf', - 'rmsnorm_fwd_inf'] +__all__ = [ + "layernorm_fwd_fp8", + "layernorm_fwd_fp8_inf", + "layernorm_fwd_inf", + "rmsnorm_fwd_fp8", + "rmsnorm_fwd_fp8_inf", + "rmsnorm_fwd_inf", +] def layernorm_fwd_fp8( @@ -91,7 +93,8 @@ def layernorm_fwd_fp8_inf( fp8_tensor, otype, sm_margin, - zero_centered_gamma) + zero_centered_gamma, + ) return ret @@ -113,6 +116,7 @@ def layernorm_fwd_inf( zero_centered_gamma, ) + def rmsnorm_fwd_fp8( inp: torch.Tensor, weight: torch.Tensor, @@ -183,7 +187,8 @@ def rmsnorm_fwd_fp8_inf( fp8_tensor, otype, sm_margin, - zero_centered_gamma) + zero_centered_gamma, + ) return ret diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index 278f3ba8ef1bb8642959cf7e5533dbff8e119bee..de83bcd7f59e92b415189bc4fa0285d2c6eae2b4 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -9,10 +9,12 @@ import transformer_engine_torch as tex from ..constants import TE_DType -__all__ = ['fp8_cast_transpose_fused', - 'fp8_cast_transpose_bgrad_fused', - 'fp8_cast_transpose_bgrad_dgelu_fused', - 'fp8_transpose_bgrad_fused'] +__all__ = [ + "fp8_cast_transpose_fused", + "fp8_cast_transpose_bgrad_fused", + "fp8_cast_transpose_bgrad_dgelu_fused", + "fp8_transpose_bgrad_fused", +] def fp8_cast_transpose_fused( @@ -28,9 +30,7 @@ def fp8_cast_transpose_fused( return_outputs = False if transpose_out is None: - transpose_out = torch.empty( - inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8 - ) + transpose_out = torch.empty(inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8) return_outputs = True if cast_out is None: cast_out = torch.empty_like(inp, dtype=torch.uint8) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 0890ca5875bf3bce9a63b6db85872ee824456249..b07c6d3508416ae85b880e794d5e58de02cfb8c4 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -11,7 +11,7 @@ import torch from .float8_tensor import Float8Tensor -__all__ = ['get_cpu_offload_context'] +__all__ = ["get_cpu_offload_context"] CPUOffloadEnabled = False @@ -69,9 +69,8 @@ class 
CpuOffloadSavedTensorHook: self.inside_context = True torch._C._autograd._push_saved_tensors_default_hooks( - self.on_save_for_backward, - self.on_get_saved_tensor - ) + self.on_save_for_backward, self.on_get_saved_tensor + ) def __exit__(self, *args: Any): global CPUOffloadEnabled @@ -80,18 +79,21 @@ class CpuOffloadSavedTensorHook: self.inside_context = False torch._C._autograd._pop_saved_tensors_default_hooks() - def on_save_for_backward(self, tensor: torch.Tensor) -> Any: """On save for backward.""" - raise NotImplementedError("`on_save_for_backward: Callable[[torch.Tensor], Any]`" - "is not implemented in CpuOffloadHook class. Inherit " - "this class and implement your custom hooks") + raise NotImplementedError( + "`on_save_for_backward: Callable[[torch.Tensor], Any]`" + "is not implemented in CpuOffloadHook class. Inherit " + "this class and implement your custom hooks" + ) def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: """On get saved tensor.""" - raise NotImplementedError("`on_get_saved_tensors: Callable[[Any], torch.Tensor]`" - "is not implemented in CpuOffloadHook class. Inherit " - "this class and implement your custom hooks") + raise NotImplementedError( + "`on_get_saved_tensors: Callable[[Any], torch.Tensor]`" + "is not implemented in CpuOffloadHook class. Inherit " + "this class and implement your custom hooks" + ) class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook): @@ -101,48 +103,48 @@ class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook): and `tensor_pop` interface. How the offload-handler manages the offloading, recovering or prefetching timing is transparent to this hook. """ + def __init__( self, offload_handler: OffloadHandler, - handler_extra_kwargs: Optional[Dict[str,Any]] = None, + handler_extra_kwargs: Optional[Dict[str, Any]] = None, debug: bool = False, ) -> None: if handler_extra_kwargs is None: handler_extra_kwargs = {} self.debug: bool = debug self.offload_handler: OffloadHandler = offload_handler - self.handler_extra_kwargs: Dict[str,Any] = handler_extra_kwargs + self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs super().__init__() def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - retrieve_identifier = self.offload_handler.tensor_push( - tensor, - **self.handler_extra_kwargs - ) + retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) return retrieve_identifier def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - tensor = self.offload_handler.tensor_pop( - saved_state, - **self.handler_extra_kwargs - ) + tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) return tensor class OffloadHandler: """A base class for CPU offload-handler.""" + def __init__(self) -> None: pass def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: """Tensor push.""" - raise NotImplementedError("`tensor_push is not implented in OffloadHandler class. " - "Inherit this class and implement your custom tensor_push.") + raise NotImplementedError( + "`tensor_push is not implented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_push." + ) def tensor_pop(self, tensor_tag: Any, **kwargs): """Tensor pop.""" - raise NotImplementedError("`tensor_pop is not implented in OffloadHandler class. " - "Inherit this class and implement your custom tensor_pop.") + raise NotImplementedError( + "`tensor_pop is not implented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_pop." 
+ ) class GroupCommitFunction(torch.autograd.Function): @@ -151,6 +153,7 @@ class GroupCommitFunction(torch.autograd.Function): accomplish all synchronizations. Implementing it as a function is necessary because we need to actions in both forward and backward. """ + @staticmethod def forward(ctx, tensor, cpu_offload_handler): cpu_offload_handler.on_group_commit_forward() @@ -173,11 +176,10 @@ class SynchronizedGroupOffloadHandler(OffloadHandler): The device-to-host and host-to-device copying happen in the same stream as the computation kernels, thus the copying will block computation. """ - def __init__(self, - num_offload_group, - tensor_need_offloading_checker=(lambda _: True), - debug=False - ) -> None: + + def __init__( + self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False + ) -> None: super().__init__() self.num_offload_group = num_offload_group @@ -199,8 +201,8 @@ class SynchronizedGroupOffloadHandler(OffloadHandler): def on_group_commit_forward(self): """On group commit forward.""" # finishing up with updating current group and tensor count - self.current_group += 1 # increment - self.tensor_count_current_group = 0 # reset + self.current_group += 1 # increment + self.tensor_count_current_group = 0 # reset def on_group_commit_backward(self): """On group commit backward.""" @@ -213,8 +215,12 @@ class SynchronizedGroupOffloadHandler(OffloadHandler): fp8_offload = isinstance(src_tensor, Float8Tensor) cpu_backup = torch.empty( - src_tensor.size(), dtype=torch.uint8 if fp8_offload else src_tensor.dtype, - layout=src_tensor.layout, device="cpu", pin_memory=pin_memory) + src_tensor.size(), + dtype=torch.uint8 if fp8_offload else src_tensor.dtype, + layout=src_tensor.layout, + device="cpu", + pin_memory=pin_memory, + ) if fp8_offload: cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup) @@ -237,8 +243,9 @@ class SynchronizedGroupOffloadHandler(OffloadHandler): tensor_tag = (self.current_group, self.tensor_count_current_group) self.tensor_count_current_group += 1 assert tensor_tag not in self.tensor_tag_to_state - if (self.current_group < self.num_offload_group - and self.tensor_need_offloading_checker(tensor)): + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker( + tensor + ): state = SynchronizedGroupOffloadHandler.offload(tensor) self.tensor_tag_to_state[tensor_tag] = state else: @@ -262,16 +269,20 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): achieves better performance due to the overlapping. D2h and h2d copying are completely hidden behind computation if computation time of a layer is longer than host-device communication time. Bulk offloading with delay and bulk reloading - with prefetch are implemented. 
""" - def __init__(self, - num_offload_group, # must be <= actual number of groups (number of commits) - num_prefetch_group=1, - tensor_need_offloading_checker=(lambda t: True), - debug=False - ) -> None: - super().__init__(num_offload_group=num_offload_group, - tensor_need_offloading_checker=tensor_need_offloading_checker, - debug=debug) + with prefetch are implemented.""" + + def __init__( + self, + num_offload_group, # must be <= actual number of groups (number of commits) + num_prefetch_group=1, + tensor_need_offloading_checker=(lambda t: True), + debug=False, + ) -> None: + super().__init__( + num_offload_group=num_offload_group, + tensor_need_offloading_checker=tensor_need_offloading_checker, + debug=debug, + ) self.num_prefetch_group = num_prefetch_group # prepare for tensor buffer @@ -300,16 +311,18 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): else: tensor_buf = id_buf_map[tensor_id] allocate_new_buf = ( - tensor_buf.size() != tensor.size() - or tensor_buf.dtype != tensor.dtype + tensor_buf.size() != tensor.size() or tensor_buf.dtype != tensor.dtype ) if allocate_new_buf: # supposed to only execute once fp8_offload = isinstance(tensor, Float8Tensor) buffer = torch.empty( - tensor.size(), dtype=torch.uint8 if fp8_offload else tensor.dtype, - layout=tensor.layout, device=tensor.device) + tensor.size(), + dtype=torch.uint8 if fp8_offload else tensor.dtype, + layout=tensor.layout, + device=tensor.device, + ) if isinstance(tensor, Float8Tensor): id_buf_map[tensor_id] = Float8Tensor.make_like(tensor, data=buffer) @@ -318,11 +331,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): return id_buf_map[tensor_id] - def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: - torch_stray_tensor = isinstance(tensor,(torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor)) + torch_stray_tensor = isinstance( + tensor, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + ) if not torch_stray_tensor: # obtain a unique tensor tag @@ -330,22 +347,23 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): self.tensor_count_current_group += 1 assert tensor_tag not in self.tensor_tag_to_state - if (self.current_group < self.num_offload_group - and self.tensor_need_offloading_checker(tensor)): + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker( + tensor + ): # first copy the tensor to tensorbuf, # so that the original tensor will not be deleted tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag) tensor_buf.copy_(tensor) - if hasattr(tensor,"weight_offloading"): + if hasattr(tensor, "weight_offloading"): tensor_buf.weight_offloading = True - if hasattr(tensor,"activation_offloading"): + if hasattr(tensor, "activation_offloading"): tensor_buf.activation_offloading = True # Here we just save it, and at commit, bulk_offload_group will handle it self.tensor_tag_to_state[tensor_tag] = tensor_buf else: self.tensor_tag_to_state[tensor_tag] = tensor else: - tensor_tag = (-1,self.torch_tensor_count) + tensor_tag = (-1, self.torch_tensor_count) self.torch_tensor_count += 1 self.tensor_tag_to_state[tensor_tag] = tensor @@ -371,10 +389,10 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): # if offload, return the reference to cpu copy if self.tensor_need_offloading_checker(tensor_on_device): - if 
hasattr(tensor_on_device,"weight_offloading"): - delattr(tensor_on_device,"weight_offloading") - if hasattr(tensor_on_device,"activation_offloading"): - delattr(tensor_on_device,"activation_offloading") + if hasattr(tensor_on_device, "weight_offloading"): + delattr(tensor_on_device, "weight_offloading") + if hasattr(tensor_on_device, "activation_offloading"): + delattr(tensor_on_device, "activation_offloading") state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) self.tensor_tag_to_state[tensor_tag] = state @@ -406,7 +424,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): # during forward, the next_group_to_fetch always points to the min of # the last commited group, and the last offloaded group - self.next_group_to_fetch = min(self.current_group, self.num_offload_group -1) + self.next_group_to_fetch = min(self.current_group, self.num_offload_group - 1) super().on_group_commit_forward() @@ -441,12 +459,13 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ): # record the event in the compute stream, for h2d to wait torch.cuda.current_stream().record_event( - self.compute_stream_bwd_start_events[group_num_to_prefetch]) + self.compute_stream_bwd_start_events[group_num_to_prefetch] + ) # start of h2d should wait for the compute and the d2h self.h2d_stream.wait_event(self.compute_stream_bwd_start_events[group_num_to_prefetch]) - #recover tensors (copy back from host) + # recover tensors (copy back from host) self.bulk_reload_group(group_num_to_prefetch) # record an event for the backward of this layer to wait @@ -464,7 +483,8 @@ def get_cpu_offload_context( enabled: bool = False, num_layers: int = 1, offload_activations: bool = True, - offload_weights: bool = True): + offload_weights: bool = True, +): """ This function returns the CPU Offload context and the synchronizer function that needs to be used after every transformer layer. Returns `nullcontext()` if offloading is not enabled. 
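As a usage illustration for the docstring above — a minimal sketch only, not part of this patch: `layers` and `hidden` are hypothetical placeholders, and the keyword values are assumptions chosen for the example, not something this file prescribes.

    # Hypothetical wiring of the CPU-offload context around a transformer layer stack.
    from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context

    offload_ctx, sync_fn = get_cpu_offload_context(
        enabled=True,
        num_layers=len(layers),        # assumed: one offload group per layer
        offload_activations=True,
        offload_weights=False,
    )

    for layer in layers:
        with offload_ctx:
            hidden = layer(hidden)     # tensors saved for backward may be offloaded to CPU
        hidden = sync_fn(hidden)       # commit the group; reload/prefetch overlaps the backward pass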
@@ -494,14 +514,14 @@ def get_cpu_offload_context( """ def tensor_need_offloading_checker_activations(tensor): - return hasattr(tensor,"activation_offloading") + return hasattr(tensor, "activation_offloading") # This includes the Gradient Accumulation Buffer def tensor_need_offloading_checker_weights(tensor): return hasattr(tensor, "weight_offloading") def tensor_need_offloading_checker_all(tensor): - return (hasattr(tensor,"activation_offloading") or hasattr(tensor, "weight_offloading")) + return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading") if offload_activations and offload_weights: tensor_need_offloading_checker = tensor_need_offloading_checker_all @@ -512,16 +532,17 @@ def get_cpu_offload_context( else: raise ValueError( "CPU Offloading is enabled while it is not " - "mentioned what to offload (weights/activations)") + "mentioned what to offload (weights/activations)" + ) cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( - num_offload_group=num_layers, - num_prefetch_group=1, - tensor_need_offloading_checker=tensor_need_offloading_checker - ) + num_offload_group=num_layers, + num_prefetch_group=1, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) def group_prefetch_offload_commit_async(tensor): - return group_prefetch_offload_commit(tensor,cpu_offload_handler) + return group_prefetch_offload_commit(tensor, cpu_offload_handler) if enabled: return ( diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 0f2a1949ffa92dc4710d20aa6ba847db71bcc2fa..99608cb735d738bcbc87196ddc128537ebd86327 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -7,14 +7,13 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ -#include -#include - #include #include #include #include #include +#include +#include #include #include #include @@ -22,19 +21,19 @@ #include "common/util/logging.h" #include "common/util/system.h" -#include "userbuffers/userbuffers.h" #include "extensions.h" +#include "userbuffers/userbuffers.h" #define HALF_BYTES 2 #define UB_MAX_SM 32 -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA Error at line %d: %s\n", __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA Error at line %d: %s\n", __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ } while (0) using namespace torch::indexing; @@ -46,7 +45,7 @@ namespace ubuf { static struct TorchCallbacks : torch::CustomClassHolder { bool initialized{false}; std::unordered_map gathered_tensors; - std::function allgather; + std::function allgather; std::function barrier; std::function free; } torch_callbacks; @@ -55,10 +54,8 @@ static struct TorchCallbacks : torch::CustomClassHolder { ** Helper function for setting Python callbacks to torch.distributed collectives. 
*/ void set_ubuf_bootstrap_callbacks( - std::function allgather, - std::function barrier, - std::function free -) { + std::function allgather, + std::function barrier, std::function free) { torch_callbacks.allgather = allgather; torch_callbacks.barrier = barrier; torch_callbacks.free = free; @@ -71,9 +68,9 @@ void set_ubuf_bootstrap_callbacks( */ void ub_alloc_copy_allgather(void **globaldata, void *localdata, size_t localbytes, char *group) { assert(torch_callbacks.initialized); - auto localtensor = torch::from_blob( - localdata, {static_cast(localbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); + auto localtensor = + torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); auto globaltensor = torch_callbacks.allgather(localtensor, group); *globaldata = globaltensor.data_ptr(); torch_callbacks.gathered_tensors[*globaldata] = globaltensor; @@ -93,8 +90,7 @@ void ub_barrier(char *group) { void ub_free(void *ptr) { assert(torch_callbacks.initialized); auto i = torch_callbacks.gathered_tensors.find(ptr); - if (i == torch_callbacks.gathered_tensors.end()) - return; + if (i == torch_callbacks.gathered_tensors.end()) return; auto tensor = std::move(i->second); torch_callbacks.gathered_tensors.erase(i); torch_callbacks.free(tensor); @@ -150,8 +146,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); } else { create_communicator_grouped2(&_ub_comm, rank, world_size, tp_rank, tp_size, 1, 1, - &ub_alloc_copy_allgather, &ub_barrier, &ub_free, - 1, 1, tp_size, 1); + &ub_alloc_copy_allgather, &ub_barrier, &ub_free, 1, 1, tp_size, + 1); } comm_created = true; } @@ -207,14 +203,14 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ - std::vector - bulk_overlap(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, transformer_engine::DType B_type, - bool transb, at::Tensor D, at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, int comm_type, at::Tensor rs_output) { + std::vector bulk_overlap( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type, + at::Tensor rs_output) { _ub_comm->use_ce = use_ce; _ub_comm->sms = comm_sms; _ub_comm->cga_size = cga_size; @@ -254,11 +250,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { NVTE_ERROR("Not supported communication type."); } - if (A_scale_inverse.numel()) - A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - if (B_scale_inverse.numel()) - B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + if (B_scale_inverse.numel()) 
B_scale_inverse = B_scale_inverse[B_fp8_tensor]; assert(pre_gelu_out.numel() == 0); te_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, D, D_scale, @@ -313,11 +307,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); - if (A_scale_inverse.numel()) - A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - if (B_scale_inverse.numel()) - B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; assert(pre_gelu_out.numel() == 0); @@ -424,11 +416,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); - if (A_scale_inverse.numel()) - A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - if (B_scale_inverse.numel()) - B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; assert(pre_gelu_out.numel() == 0); @@ -538,8 +528,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } } for (size_t i = 0; i < _stream_compute.size(); i++) { - CHECK_CUDA( - cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); } _ub_comm->sms = ori_sms; @@ -584,8 +573,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { torch::Tensor &get_ubuf_output(int comm_type) { char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) - NVTE_ERROR("Invalid comm_type"); + if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); if (_comm_type == COMM_TYPE::RS) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? 
_ubuf.size(0) : _ubuf.size(0) / _tp_size; @@ -635,8 +623,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); } else { create_communicator_grouped2(&_ub_comm, rank, world_size, tp_rank, tp_size, 1, 1, - &ub_alloc_copy_allgather, &ub_barrier, &ub_free, - 1, 1, tp_size, 1); + &ub_alloc_copy_allgather, &ub_barrier, &ub_free, 1, 1, tp_size, + 1); } comm_created = true; } @@ -662,7 +650,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { } _ubuf = torch::from_blob( - _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options()); + _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options()); // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(_ubuf.data_ptr()); @@ -758,11 +746,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int *counter_ptr = reinterpret_cast(counter.data_ptr()); int workspace_size_chunk = workspaceSize / _stream_compute.size(); - if (A_scale_inverse.numel()) - A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - if (B_scale_inverse.numel()) - B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; assert(pre_gelu_out.numel() == 0); @@ -789,14 +775,14 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { if (env_p != nullptr && env_p[0] == '1') { if (i == 0) { userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, - _ub_comm, _next_rank, _prev_rank, _tp_size, - counter_ptr, true, (cudaStream_t)_stream_recv); + _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, + true, (cudaStream_t)_stream_recv); } } else { - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, - _ub_comm, _next_rank, (cudaStream_t) _stream_recv); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, - _ub_comm, _prev_rank, (cudaStream_t) _stream_recv); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + _next_rank, (cudaStream_t)_stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + _prev_rank, (cudaStream_t)_stream_recv); producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv); } if (i == 0) { @@ -811,10 +797,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { if (B_copy.numel() > 0) { assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); - CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(), - _ubufs[_self_chunk_id].numel() * - _ubufs[_self_chunk_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); + CHECK_CUDA( + cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(), + _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); } @@ -824,12 +810,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Copy the first GEMM output chunk to the end chunk position of D_buffer char *src_ptr = reinterpret_cast(D_buffer.data_ptr()); - CHECK_CUDA(cudaMemcpyAsync( - src_ptr + (D.numel() * 
D.element_size()), - src_ptr, - n_chunk * m * D.element_size(), - cudaMemcpyDeviceToDevice, - (cudaStream_t) stream_main)); + CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, + n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice, + (cudaStream_t)stream_main)); // Return the last N rows of D_buffer torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n); return D_return; @@ -871,11 +854,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); int workspace_size_chunk = workspaceSize / _stream_compute.size(); - if (A_scale_inverse.numel()) - A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - if (B_scale_inverse.numel()) - B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); @@ -920,10 +901,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { torch::Tensor output_chunk = torch::from_blob( output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk * 2, m}, D.options()); if (do_gelu) { - pre_gelu_out = torch::from_blob( - pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), - {n_chunk * 2, m}, - pre_gelu_out.options()); + pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), + {n_chunk * 2, m}, pre_gelu_out.options()); } torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, @@ -967,10 +946,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { torch::Tensor output_chunk = torch::from_blob( output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk, m}, D.options()); if (do_gelu) { - pre_gelu_out = torch::from_blob( - pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), - {n_chunk, m}, - pre_gelu_out.options()); + pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), + {n_chunk, m}, pre_gelu_out.options()); } torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, @@ -1001,8 +978,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { } } for (size_t i = 0; i < _stream_compute.size(); i++) { - CHECK_CUDA( - cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); } CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); @@ -1014,18 +990,18 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { return D; } // split_overlap_ag -/* + /* ** Split ReduceScatter + GEMM using P2P communication */ void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor rs_output) { + transformer_engine::DType A_type, 
bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { _ub_comm->use_ce = use_ce; _ub_comm->sms = sms; _ub_comm->cga_size = cga_size; @@ -1038,11 +1014,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int *counter_ptr = reinterpret_cast(counter.data_ptr()); int workspace_size_chunk = workspaceSize / _stream_compute.size(); - if (A_scale_inverse.numel()) - A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - if (B_scale_inverse.numel()) - B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; // Catch up the main stream at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); @@ -1052,11 +1026,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Atomic GEMM // Process GEMM chunks in the order that AG+GEMM places the output chunks. torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - _ubuf, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms, 0, _tp_size, true, counter); + torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); + te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, _ubuf, + D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace_chunk, + workspace_size_chunk, accumulate, use_split_accumulator, _math_sms, 0, _tp_size, + true, counter); // P2P communication chunk for (int i = 1; i < _tp_size; i++) { @@ -1068,13 +1042,13 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv); - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, - _ub_comm, send_rank, (cudaStream_t) _stream_recv); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, - _ub_comm, recv_rank, (cudaStream_t) _stream_recv); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, + (cudaStream_t)_stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, + (cudaStream_t)_stream_recv); } - CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t) _stream_recv)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) stream_main, _stop_recv, 0)); + CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); // Reduce GEMM output chunks char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); @@ -1083,10 +1057,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, 
d_scale_inv_ptr, - _tp_size, _ubufs[0].numel(), (cudaStream_t) stream_main); + _tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main); } else { torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); torch::sum_out(rs_output, reduce_buf, 0); } } @@ -1119,11 +1093,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); int workspace_size_chunk = workspaceSize / _stream_compute.size(); - if (A_scale_inverse.numel()) - A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - if (B_scale_inverse.numel()) - B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; // Catch up the main stream at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); @@ -1131,47 +1103,46 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); for (size_t i = 0; i < _stream_compute.size(); i++) { - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_compute[i], _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); } // GEMM and send/recv chunks for (int i = 0; i < _tp_size; i++) { // GEMM chunk int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; - char* input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); + char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); torch::Tensor input_b_chunk = torch::from_blob(input_b_chunk_ptr, {n_chunk, k}, B.options()); // Store the last GEMM chunk output to the recieve buffer. 
- torch::Tensor workspace_chunk = torch::from_blob( - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); if (i == _tp_size - 1) { - at::cuda::setCurrentCUDAStream(stream_main); + at::cuda::setCurrentCUDAStream(stream_main); } else { - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); } te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, _ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); if (i > 0) { - // P2P communication chunk - int send_offset = comm_bytes * (i - 1); - int recv_offset = comm_bytes * (i - 1 + _tp_size); - int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - CHECK_CUDA(cudaEventRecord( - _start_comm, (cudaStream_t) _stream_compute[(i - 1) % _stream_compute.size()])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_send, _start_comm, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_recv, _start_comm, 0)); - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, - _ub_comm, send_rank, (cudaStream_t) _stream_send); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, - _ub_comm, recv_rank, (cudaStream_t) _stream_recv); + // P2P communication chunk + int send_offset = comm_bytes * (i - 1); + int recv_offset = comm_bytes * (i - 1 + _tp_size); + int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + CHECK_CUDA(cudaEventRecord( + _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0)); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + send_rank, (cudaStream_t)_stream_send); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + recv_rank, (cudaStream_t)_stream_recv); } } - CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t) _stream_recv)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) stream_main, _stop_recv, 0)); + CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); // Reduce GEMM output chunks char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); @@ -1180,15 +1151,14 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, - _tp_size, _ubufs[0].numel(), (cudaStream_t) stream_main); + _tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main); } else { torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + reduce_buf_ptr, {_tp_size, 
_ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); torch::sum_out(rs_output, reduce_buf, 0); } for (size_t i = 0; i < _stream_compute.size(); i++) { - CHECK_CUDA( - cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); } CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); @@ -1221,8 +1191,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { torch::Tensor get_ubuf_output(int comm_type) { char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) - NVTE_ERROR("Invalid comm_type"); + if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); if (_comm_type == COMM_TYPE::RS) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; diff --git a/transformer_engine/pytorch/csrc/common.cu b/transformer_engine/pytorch/csrc/common.cu index 981bc3bf986036cb07c246bebfdc9fe5546f0a95..2d8e602c5b0879be84b0b49925e344e3fb967f95 100644 --- a/transformer_engine/pytorch/csrc/common.cu +++ b/transformer_engine/pytorch/csrc/common.cu @@ -7,146 +7,116 @@ #include "common.h" #include "transformer_engine/transformer_engine.h" - transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, - const std::string &fp8_recipe) { - // if e4m3 or hybrid + forward - if ( (fp8_recipe == "E4M3") || ( (fp8_recipe == "HYBRID") && e4m3_if_hybrid ) ) { - return transformer_engine::DType::kFloat8E4M3; - } - return transformer_engine::DType::kFloat8E5M2; + const std::string& fp8_recipe) { + // if e4m3 or hybrid + forward + if ((fp8_recipe == "E4M3") || ((fp8_recipe == "HYBRID") && e4m3_if_hybrid)) { + return transformer_engine::DType::kFloat8E4M3; + } + return transformer_engine::DType::kFloat8E5M2; } transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, - const NVTEShape& shape, - const transformer_engine::DType type) { + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type) { return transformer_engine::TensorWrapper(data_ptr, shape, type); } - transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, - const std::vector& shape, - const transformer_engine::DType type) { + void* data_ptr, const std::vector& shape, const transformer_engine::DType type) { return transformer_engine::TensorWrapper(data_ptr, shape, type); } - transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) { - transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); - std::vector shape; + transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); + std::vector shape; - for (auto s : tensor.sizes()) { - shape.push_back(s); - } - return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); + for (auto s : tensor.sizes()) { + shape.push_back(s); + } + return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); } - -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, - const std::vector& shape, - const transformer_engine::DType type, - void* amax_ptr, - void* scale_ptr, - void* scale_inv_ptr) { - return transformer_engine::TensorWrapper(data_ptr, shape, type, - reinterpret_cast(amax_ptr), - 
reinterpret_cast(scale_ptr), - reinterpret_cast(scale_inv_ptr)); +transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, + const std::vector& shape, + const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, + void* scale_inv_ptr) { + return transformer_engine::TensorWrapper( + data_ptr, shape, type, reinterpret_cast(amax_ptr), + reinterpret_cast(scale_ptr), reinterpret_cast(scale_inv_ptr)); } - -transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, - at::Tensor amax, +transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv) { - transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); - std::vector shape; - - for (auto s : tensor.sizes()) { - shape.push_back(s); - } - NVTE_CHECK(amax.scalar_type() == at::kFloat); - NVTE_CHECK(scale.scalar_type() == at::kFloat); - NVTE_CHECK(scale_inv.scalar_type() == at::kFloat); - - return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype, - amax.data_ptr(), - scale.data_ptr(), - scale_inv.data_ptr()); + transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); + std::vector shape; + + for (auto s : tensor.sizes()) { + shape.push_back(s); + } + NVTE_CHECK(amax.scalar_type() == at::kFloat); + NVTE_CHECK(scale.scalar_type() == at::kFloat); + NVTE_CHECK(scale_inv.scalar_type() == at::kFloat); + + return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); } - -size_t product(const std::vector &shape) { - size_t ret = 1; - for (auto s : shape) { - ret *= s; - } - return ret; +size_t product(const std::vector& shape) { + size_t ret = 1; + for (auto s : shape) { + ret *= s; + } + return ret; } - -at::Tensor allocateSpace(const std::vector& shape, - const transformer_engine::DType type, +at::Tensor allocateSpace(const std::vector& shape, const transformer_engine::DType type, bool init_to_zeros) { - std::vector shape_int64(shape.begin(), shape.end()); - c10::IntArrayRef ar_shape(shape_int64); - if (init_to_zeros) { - return at::zeros(ar_shape, at::CUDA(GetATenDType(type))); - } else { - return at::empty(ar_shape, at::CUDA(GetATenDType(type))); - } + std::vector shape_int64(shape.begin(), shape.end()); + c10::IntArrayRef ar_shape(shape_int64); + if (init_to_zeros) { + return at::zeros(ar_shape, at::CUDA(GetATenDType(type))); + } else { + return at::empty(ar_shape, at::CUDA(GetATenDType(type))); + } } - -at::Tensor allocateSpace(const NVTEShape &shape, - const transformer_engine::DType type, +at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type, bool init_to_zeros) { - auto size = shape.ndim; - if (size == 2 && init_to_zeros) { - return at::zeros({static_cast(shape.data[0]), - static_cast(shape.data[1])}, - at::CUDA(GetATenDType(type))); - } else if (size == 2) { - return at::empty({static_cast(shape.data[0]), - static_cast(shape.data[1])}, - at::CUDA(GetATenDType(type))); - } else if (size == 1 && init_to_zeros) { - return at::zeros({static_cast(shape.data[0])}, at::CUDA(GetATenDType(type))); - } else if (size == 1) { - return at::empty({static_cast(shape.data[0])}, at::CUDA(GetATenDType(type))); - } - NVTE_CHECK(false, "Should never reach here! 
func: allocateSpace"); + auto size = shape.ndim; + if (size == 2 && init_to_zeros) { + return at::zeros({static_cast(shape.data[0]), static_cast(shape.data[1])}, + at::CUDA(GetATenDType(type))); + } else if (size == 2) { + return at::empty({static_cast(shape.data[0]), static_cast(shape.data[1])}, + at::CUDA(GetATenDType(type))); + } else if (size == 1 && init_to_zeros) { + return at::zeros({static_cast(shape.data[0])}, at::CUDA(GetATenDType(type))); + } else if (size == 1) { + return at::empty({static_cast(shape.data[0])}, at::CUDA(GetATenDType(type))); + } + NVTE_CHECK(false, "Should never reach here! func: allocateSpace"); } - -at::Tensor allocateTorchTensor(int M, - int N, - transformer_engine::DType dtype -) { - return at::empty({static_cast(M), static_cast(N)}, - at::CUDA(GetATenDType(dtype))); +at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype) { + return at::empty({static_cast(M), static_cast(N)}, + at::CUDA(GetATenDType(dtype))); } - -at::Tensor allocateTorchTensor(int M, - transformer_engine::DType dtype -) { - return at::empty({static_cast(M)}, - at::CUDA(GetATenDType(dtype))); +at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype) { + return at::empty({static_cast(M)}, at::CUDA(GetATenDType(dtype))); } void* getDataPtr(at::Tensor tensor, int offset) { - void* dptr = nullptr; - if (tensor.numel() > 0) { - dptr = tensor.data_ptr(); - } - if (dptr != nullptr && offset != 0) { - char* char_ptr = reinterpret_cast(dptr); - char_ptr += offset * tensor.element_size(); - dptr = reinterpret_cast(char_ptr); - } - return dptr; + void* dptr = nullptr; + if (tensor.numel() > 0) { + dptr = tensor.data_ptr(); + } + if (dptr != nullptr && offset != 0) { + char* char_ptr = reinterpret_cast(dptr); + char_ptr += offset * tensor.element_size(); + dptr = reinterpret_cast(char_ptr); + } + return dptr; } diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 9741fb2032fa23d21a8fac46e59f5cc080b84743..aac693a4308a42c34d15d02e50b933ff97f21c5d 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -7,19 +7,10 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ -#include -#include -#include -#include -#include -#include -#include - #include #include #include #include -#include #include #include #include @@ -30,10 +21,9 @@ #include #include #include - -#include "common/util/logging.h" #include #include +#include #include #include #include @@ -43,7 +33,17 @@ #include #include #include -#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/util/logging.h" namespace transformer_engine { @@ -51,140 +51,118 @@ namespace transformer_engine { // data for a single FP8 block, e.g. LayerNormLinear class FP8TensorMeta { public: - at::Tensor scale; - at::Tensor scale_inv; - at::Tensor amax_history; + at::Tensor scale; + at::Tensor scale_inv; + at::Tensor amax_history; }; // Used as named indices on the `scale`, `scale_inv`, // and `amax` tensors in the `FP8TensorMeta` class. 
enum FP8FwdTensors { - GEMM1_INPUT = 0, - GEMM1_WEIGHT = 1, - GEMM1_OUTPUT = 2, - GEMM2_INPUT = 3, - GEMM2_WEIGHT = 4, - GEMM2_OUTPUT = 5, - GEMM3_INPUT = 6, - GEMM3_WEIGHT = 7, - GEMM3_OUTPUT = 8 + GEMM1_INPUT = 0, + GEMM1_WEIGHT = 1, + GEMM1_OUTPUT = 2, + GEMM2_INPUT = 3, + GEMM2_WEIGHT = 4, + GEMM2_OUTPUT = 5, + GEMM3_INPUT = 6, + GEMM3_WEIGHT = 7, + GEMM3_OUTPUT = 8 }; // Used as named indices on the `scale`, `scale_inv`, // and `amax` tensors in the `FP8TensorMeta` class. enum FP8BwdTensors { - GRAD_OUTPUT1 = 0, - GRAD_INPUT1 = 1, - GRAD_OUTPUT2 = 2, - GRAD_INPUT2 = 3, - GRAD_OUTPUT3 = 4, - GRAD_INPUT3 = 5 + GRAD_OUTPUT1 = 0, + GRAD_INPUT1 = 1, + GRAD_OUTPUT2 = 2, + GRAD_INPUT2 = 3, + GRAD_OUTPUT3 = 4, + GRAD_INPUT3 = 5 }; - } // namespace transformer_engine - transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, - const std::string &fp8_recipe); - + const std::string& fp8_recipe); inline at::ScalarType GetATenDType(transformer_engine::DType t) { - switch (t) { - case transformer_engine::DType::kInt32: - return torch::kInt32; - case transformer_engine::DType::kInt64: - return torch::kInt64; - case transformer_engine::DType::kFloat32: - return at::kFloat; - case transformer_engine::DType::kFloat16: - return at::kHalf; - case transformer_engine::DType::kBFloat16: - return at::kBFloat16; - case transformer_engine::DType::kByte: - case transformer_engine::DType::kFloat8E4M3: - case transformer_engine::DType::kFloat8E5M2: - return at::kByte; - default: - NVTE_ERROR("Invalid type"); - } + switch (t) { + case transformer_engine::DType::kInt32: + return torch::kInt32; + case transformer_engine::DType::kInt64: + return torch::kInt64; + case transformer_engine::DType::kFloat32: + return at::kFloat; + case transformer_engine::DType::kFloat16: + return at::kHalf; + case transformer_engine::DType::kBFloat16: + return at::kBFloat16; + case transformer_engine::DType::kByte: + case transformer_engine::DType::kFloat8E4M3: + case transformer_engine::DType::kFloat8E5M2: + return at::kByte; + default: + NVTE_ERROR("Invalid type"); + } } - inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { - switch (t) { - case at::kHalf: - return transformer_engine::DType::kFloat16; - case at::kFloat: - return transformer_engine::DType::kFloat32; - case at::kBFloat16: - return transformer_engine::DType::kBFloat16; - case at::kBool: - return transformer_engine::DType::kByte; - case torch::kByte: - return transformer_engine::DType::kByte; - case torch::kInt32: - return transformer_engine::DType::kInt32; - case torch::kInt64: - return transformer_engine::DType::kInt64; - default: - NVTE_ERROR("Invalid type"); - } + switch (t) { + case at::kHalf: + return transformer_engine::DType::kFloat16; + case at::kFloat: + return transformer_engine::DType::kFloat32; + case at::kBFloat16: + return transformer_engine::DType::kBFloat16; + case at::kBool: + return transformer_engine::DType::kByte; + case torch::kByte: + return transformer_engine::DType::kByte; + case torch::kInt32: + return transformer_engine::DType::kInt32; + case torch::kInt64: + return transformer_engine::DType::kInt64; + default: + NVTE_ERROR("Invalid type"); + } } - inline transformer_engine::DType GetTransformerEngineDType(int DType_value) { - return static_cast(DType_value); + return static_cast(DType_value); } transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, const std::vector& shape, - const transformer_engine::DType type -); + const transformer_engine::DType type); 
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, - void* scale_ptr, - void* scale_inv_ptr -); - + void* amax_ptr, void* scale_ptr, + void* scale_inv_ptr); transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, const NVTEShape& shape, - const transformer_engine::DType type -); - + const transformer_engine::DType type); transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor); -transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, - at::Tensor amax, +transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv); +size_t product(const std::vector& shape); -size_t product(const std::vector &shape); - -at::Tensor allocateSpace(const std::vector& shape, - const transformer_engine::DType type, +at::Tensor allocateSpace(const std::vector& shape, const transformer_engine::DType type, bool init_to_zeros); -at::Tensor allocateSpace(const NVTEShape &shape, - const transformer_engine::DType type, +at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type, bool init_to_zeros = false); +at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype); -at::Tensor allocateTorchTensor(int M, - int N, - transformer_engine::DType dtype -); - - -at::Tensor allocateTorchTensor(int M, - transformer_engine::DType dtype -); +at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype); void* getDataPtr(at::Tensor tensor, int offset = 0); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1799622777d34f765c65979cce58cf284e676e0b..05c43ea293835ce93676d9aff3a72934f02ca4d3 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -14,179 +14,95 @@ * Attention **************************************************************************************************/ -NVTE_Fused_Attn_Backend get_fused_attn_backend( - const transformer_engine::DType q_dtype, - const transformer_engine::DType kv_dtype, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - float p_dropout, - size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim); +NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype, + const transformer_engine::DType kv_dtype, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, float p_dropout, + size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim); std::vector fused_attn_fwd_qkvpacked( - size_t max_seqlen, bool is_training, - float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens, - const at::Tensor QKV, - const transformer_engine::DType qkv_type, - const c10::optional seq_offsets_q, - const c10::optional seq_offsets_k, - const c10::optional seq_offsets_v, - const c10::optional seq_offsets_o, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional scale_S, - const c10::optional scale_O, - c10::optional amax_S, - c10::optional amax_O, - const c10::optional Bias, - const c10::optional rng_gen, - size_t rng_elts_per_thread); + size_t 
max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, + const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional seq_offsets_o, + const c10::optional descale_QKV, const c10::optional descale_S, + const c10::optional scale_S, const c10::optional scale_O, + c10::optional amax_S, c10::optional amax_O, + const c10::optional Bias, const c10::optional rng_gen, + size_t rng_elts_per_thread); std::vector fused_attn_bwd_qkvpacked( - size_t max_seqlen, float attn_scale, - float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens, - const at::Tensor QKV, - const at::Tensor O, - const at::Tensor dO, - const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, - const std::vector Aux_CTX_Tensors, - const c10::optional seq_offsets_q, - const c10::optional seq_offsets_k, - const c10::optional seq_offsets_v, - const c10::optional seq_offsets_o, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional descale_O, - const c10::optional descale_dO, - const c10::optional descale_dP, - const c10::optional scale_S, - const c10::optional scale_dP, - const c10::optional scale_dQKV, - c10::optional amax_dP, - c10::optional amax_dQKV); + size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens, + const at::Tensor QKV, const at::Tensor O, const at::Tensor dO, + const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, + const std::vector Aux_CTX_Tensors, const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional descale_O, + const c10::optional descale_dO, const c10::optional descale_dP, + const c10::optional scale_S, const c10::optional scale_dP, + const c10::optional scale_dQKV, c10::optional amax_dP, + c10::optional amax_dQKV); std::vector fused_attn_fwd_kvpacked( - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, - const at::Tensor Q, - const at::Tensor KV, - const transformer_engine::DType qkv_type, - const c10::optional seq_offsets_q, - const c10::optional seq_offsets_k, - const c10::optional seq_offsets_v, - const c10::optional seq_offsets_o, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional scale_S, - const c10::optional scale_O, - c10::optional amax_S, - c10::optional amax_O, - const c10::optional Bias, - const c10::optional rng_gen, - size_t rng_elts_per_thread); + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, + const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, 
+ const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional seq_offsets_o, + const c10::optional descale_QKV, const c10::optional descale_S, + const c10::optional scale_S, const c10::optional scale_O, + c10::optional amax_S, c10::optional amax_O, + const c10::optional Bias, const c10::optional rng_gen, + size_t rng_elts_per_thread); std::vector fused_attn_bwd_kvpacked( - size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, - const at::Tensor Q, - const at::Tensor KV, - const at::Tensor O, - const at::Tensor dO, - const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, - const std::vector Aux_CTX_Tensors, - const c10::optional seq_offsets_q, - const c10::optional seq_offsets_k, - const c10::optional seq_offsets_v, - const c10::optional seq_offsets_o, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional descale_O, - const c10::optional descale_dO, - const c10::optional descale_dP, - const c10::optional scale_S, - const c10::optional scale_dP, - const c10::optional scale_dQKV, - c10::optional amax_dP, - c10::optional amax_dQKV); + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, + const at::Tensor KV, const at::Tensor O, const at::Tensor dO, + const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, + const std::vector Aux_CTX_Tensors, const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional descale_O, + const c10::optional descale_dO, const c10::optional descale_dP, + const c10::optional scale_S, const c10::optional scale_dP, + const c10::optional scale_dQKV, c10::optional amax_dP, + c10::optional amax_dQKV); std::vector fused_attn_fwd( - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, - const at::Tensor Q, - const at::Tensor K, - const at::Tensor V, - const transformer_engine::DType qkv_type, - const c10::optional seq_offsets_q, - const c10::optional seq_offsets_k, - const c10::optional seq_offsets_v, - const c10::optional seq_offsets_o, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional scale_S, - const c10::optional scale_O, - c10::optional amax_S, - c10::optional amax_O, - const c10::optional Bias, - const c10::optional rng_gen, - size_t rng_elts_per_thread); + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, + const at::Tensor Q, const at::Tensor K, const at::Tensor V, + const transformer_engine::DType qkv_type, const c10::optional seq_offsets_q, + const c10::optional 
seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional scale_S, + const c10::optional scale_O, c10::optional amax_S, + c10::optional amax_O, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd( - size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, - const at::Tensor Q, - const at::Tensor K, - const at::Tensor V, - const at::Tensor O, - const at::Tensor dO, - const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, - const std::vector Aux_CTX_Tensors, - const c10::optional seq_offsets_q, - const c10::optional seq_offsets_k, - const c10::optional seq_offsets_v, - const c10::optional seq_offsets_o, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional descale_O, - const c10::optional descale_dO, - const c10::optional descale_dP, - const c10::optional scale_S, - const c10::optional scale_dP, - const c10::optional scale_dQKV, - c10::optional amax_dP, - c10::optional amax_dQKV); + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, + const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO, + const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, + const std::vector Aux_CTX_Tensors, const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional descale_O, + const c10::optional descale_dO, const c10::optional descale_dP, + const c10::optional scale_S, const c10::optional scale_dP, + const c10::optional scale_dQKV, c10::optional amax_dP, + c10::optional amax_dQKV); at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); @@ -195,118 +111,55 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); * GEMM **************************************************************************************************/ -void te_gemm(at::Tensor A, - at::Tensor A_scale_inverse, - transformer_engine::DType A_type, - bool transa, - at::Tensor B, - at::Tensor B_scale_inverse, - transformer_engine::DType B_type, - bool transb, - at::Tensor D, - at::Tensor D_scale, - transformer_engine::DType D_type, - at::Tensor D_amax, - at::Tensor bias, - transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, - bool grad, - at::Tensor workspace, - size_t workspaceSize, - bool accumulate, - bool use_split_accumulator, - int math_sm_count -); - -void te_atomic_gemm(at::Tensor A, - at::Tensor A_scale_inverse, - transformer_engine::DType A_type, - bool transa, - at::Tensor B, - at::Tensor B_scale_inverse, - transformer_engine::DType B_type, - bool transb, - at::Tensor D, - at::Tensor D_scale, - transformer_engine::DType D_type, - at::Tensor D_amax, - at::Tensor bias, - transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, - bool grad, - at::Tensor workspace, - size_t workspaceSize, - bool 
accumulate, - bool use_split_accumulator, - int math_sm_count, - int m_split, - int n_split, - bool gemm_producer, - at::Tensor counter -); +void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, + bool transa, at::Tensor B, at::Tensor B_scale_inverse, + transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count); + +void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, + bool transa, at::Tensor B, at::Tensor B_scale_inverse, + transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count, int m_split, int n_split, + bool gemm_producer, at::Tensor counter); /*************************************************************************************************** * Transpose **************************************************************************************************/ -void fused_cast_transpose(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - at::Tensor input_cast, - at::Tensor input_transpose, - transformer_engine::DType otype -); - - -void fused_cast_transpose_noop(at::Tensor input, - at::Tensor noop, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - at::Tensor input_cast, - at::Tensor input_transpose, - transformer_engine::DType otype, - int scale_offset = 0, - int amax_offset = 0, - int scale_inv_offset = 0 -); - - -std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset = 0, - int amax_offset = 0, - int scale_inv_offset = 0 -); +void fused_cast_transpose(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + at::Tensor input_cast, at::Tensor input_transpose, + transformer_engine::DType otype); + +void fused_cast_transpose_noop(at::Tensor input, at::Tensor noop, at::Tensor scale, at::Tensor amax, + at::Tensor scale_inv, at::Tensor input_cast, + at::Tensor input_transpose, transformer_engine::DType otype, + int scale_offset = 0, int amax_offset = 0, int scale_inv_offset = 0); +std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, + int scale_offset = 0, int amax_offset = 0, + int scale_inv_offset = 0); -std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, +std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, transformer_engine::DType grad_bias_type, - int scale_offset = 0, - int amax_offset = 0, - int scale_inv_offset = 0 -); - + int scale_offset = 0, int amax_offset = 0, + int scale_inv_offset = 0); std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, - at::Tensor gelu_input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, + at::Tensor gelu_input, at::Tensor scale, + 
at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - int scale_offset = 0, - int amax_offset = 0, - int scale_inv_offset = 0 -); - + int scale_offset = 0, int amax_offset = 0, + int scale_inv_offset = 0); void fused_multi_cast_transpose(std::vector input_list, std::vector scale_list, @@ -314,350 +167,171 @@ void fused_multi_cast_transpose(std::vector input_list, std::vector transposed_output_list, std::vector amax_output_list, std::vector scale_inv_output_list, - transformer_engine::DType otype -); + transformer_engine::DType otype); +at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype); -at::Tensor fp8_transpose(at::Tensor input, - transformer_engine::DType otype -); +void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype); -void fp8_transpose_noalloc(at::Tensor input, - at::Tensor output, - transformer_engine::DType otype -); - -void fp8_transpose_noalloc_noop(at::Tensor input, - at::Tensor output, - at::Tensor noop, - transformer_engine::DType otype -); +void fp8_transpose_noalloc_noop(at::Tensor input, at::Tensor output, at::Tensor noop, + transformer_engine::DType otype); /*************************************************************************************************** * Activations **************************************************************************************************/ -at::Tensor gelu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -); - -at::Tensor relu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -); - -at::Tensor geglu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -); - -at::Tensor reglu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -); - -at::Tensor swiglu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -); - -at::Tensor qgelu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -); - -at::Tensor srelu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -); - -at::Tensor dgelu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -); - -at::Tensor drelu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -); - -at::Tensor dgeglu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -); - -at::Tensor dreglu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -); - -at::Tensor dswiglu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -); - -at::Tensor dqgelu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -); - -at::Tensor dsrelu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -); +at::Tensor gelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype); + +at::Tensor relu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype); + +at::Tensor geglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype); + +at::Tensor reglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype); 
+ +at::Tensor swiglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype); + +at::Tensor qgelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype); + +at::Tensor srelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype); + +at::Tensor dgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); + +at::Tensor drelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); + +at::Tensor dgeglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); + +at::Tensor dreglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); + +at::Tensor dswiglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); + +at::Tensor dqgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); + +at::Tensor dsrelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); /*************************************************************************************************** * LayerNorm **************************************************************************************************/ -std::vector layernorm_bwd(const at::Tensor &dz, - const at::Tensor &x, - const at::Tensor &mu, - const at::Tensor &rsigma, - const at::Tensor &gamma, - const int sm_margin, - const bool zero_centered_gamma -); - - -std::vector layernorm_fwd_fp8(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, +std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, + const at::Tensor &mu, const at::Tensor &rsigma, + const at::Tensor &gamma, const int sm_margin, + const bool zero_centered_gamma); + +std::vector layernorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, float eps, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma, - const int scale_offset = 0, - const int amax_offset = 0, - const int scale_inv_offset = 0 -); - -std::vector layernorm_fwd_fp8_noalloc(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - at::Tensor scale, - at::Tensor ln_out, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset = 0, - const int amax_offset = 0, - const int scale_inv_offset = 0 -); - -at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset = 0, - const int amax_offset = 0, - const int scale_inv_offset = 0 -); - -std::vector layernorm_fwd(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - const int sm_margin, - const bool zero_centered_gamma -); - -std::vector layernorm_fwd_noalloc(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - at::Tensor ln_out, - float eps, - const int sm_margin, - const bool zero_centered_gamma -); - -at::Tensor layernorm_fwd_inf(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - 
float eps, - const int sm_margin, - const bool zero_centered_gamma -); + const int scale_offset = 0, const int amax_offset = 0, + const int scale_inv_offset = 0); + +std::vector layernorm_fwd_fp8_noalloc( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps, + at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma, + const int scale_offset = 0, const int amax_offset = 0, const int scale_inv_offset = 0); + +at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, float eps, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, const int sm_margin, + const bool zero_centered_gamma, const int scale_offset = 0, + const int amax_offset = 0, const int scale_inv_offset = 0); + +std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, float eps, const int sm_margin, + const bool zero_centered_gamma); + +std::vector layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, at::Tensor ln_out, float eps, + const int sm_margin, const bool zero_centered_gamma); + +at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, float eps, const int sm_margin, + const bool zero_centered_gamma); /*************************************************************************************************** * RMSNorm **************************************************************************************************/ -std::vector rmsnorm_bwd(const at::Tensor &dz, - const at::Tensor &x, - const at::Tensor &rsigma, - const at::Tensor &gamma, - const int sm_margin, - const bool zero_centered_gamma -); - - -std::vector rmsnorm_fwd_fp8(const at::Tensor &input, - const at::Tensor &weight, - float eps, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset = 0, - const int amax_offset = 0, - const int scale_inv_offset = 0 -); - -std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, - const at::Tensor &weight, - float eps, - at::Tensor scale, - at::Tensor ln_out, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset = 0, - const int amax_offset = 0, - const int scale_inv_offset = 0 -); - -at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, - const at::Tensor &weight, - float eps, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset = 0, - const int amax_offset = 0, - const int scale_inv_offset = 0 -); - -std::vector rmsnorm_fwd(const at::Tensor &input, - const at::Tensor &weight, - float eps, - const int sm_margin, - const bool zero_centered_gamma -); - -std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, - const at::Tensor &weight, - at::Tensor ln_out, - float eps, - const int sm_margin, - const bool zero_centered_gamma -); - -at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, - const at::Tensor &weight, - float eps, - const int sm_margin, - const bool zero_centered_gamma -); +std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, + const at::Tensor &rsigma, const at::Tensor &gamma, + const 
int sm_margin, const bool zero_centered_gamma); -/*************************************************************************************************** - * Cast - **************************************************************************************************/ +std::vector rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, + float eps, at::Tensor scale, at::Tensor amax, + at::Tensor scale_inv, transformer_engine::DType otype, + const int sm_margin, const bool zero_centered_gamma, + const int scale_offset = 0, const int amax_offset = 0, + const int scale_inv_offset = 0); -at::Tensor cast_to_fp8(const at::Tensor &input, - const at::Tensor &scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -); +std::vector rmsnorm_fwd_fp8_noalloc( + const at::Tensor &input, const at::Tensor &weight, float eps, at::Tensor scale, + at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int sm_margin, const bool zero_centered_gamma, const int scale_offset = 0, + const int amax_offset = 0, const int scale_inv_offset = 0); +at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, float eps, + at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, const int sm_margin, + const bool zero_centered_gamma, const int scale_offset = 0, + const int amax_offset = 0, const int scale_inv_offset = 0); -void cast_to_fp8_noalloc(const at::Tensor &input, - const at::Tensor &scale, - at::Tensor output, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -); +std::vector rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps, + const int sm_margin, const bool zero_centered_gamma); +std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, + at::Tensor ln_out, float eps, const int sm_margin, + const bool zero_centered_gamma); -at::Tensor cast_from_fp8(const at::Tensor &input, - const at::Tensor &scale_inv, - transformer_engine::DType itype, - transformer_engine::DType otype -); +at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps, + const int sm_margin, const bool zero_centered_gamma); /*************************************************************************************************** - * Softmax + * Cast **************************************************************************************************/ -at::Tensor scaled_softmax_forward(at::Tensor input, - float scale_factor -); +at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, + at::Tensor scale_inv, transformer_engine::DType otype); +void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype); -at::Tensor scaled_softmax_backward(at::Tensor output_grad_, - at::Tensor softmax_results_, - float scale_factor -); +at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, + transformer_engine::DType itype, transformer_engine::DType otype); +/*************************************************************************************************** + * Softmax + **************************************************************************************************/ -at::Tensor scaled_masked_softmax_forward(at::Tensor input, - at::Tensor mask, - float scale_factor -); - +at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor); -at::Tensor 
scaled_masked_softmax_backward(at::Tensor output_grad_, - at::Tensor softmax_results_, - float scale_factor -); +at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, + float scale_factor); +at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor); -at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, - float scale_factor -); +at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, + float scale_factor); +at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor); at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, at::Tensor softmax_results_, - float scale_factor -); - - -at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, - float scale_factor -); + float scale_factor); +at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor); at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_, at::Tensor softmax_results_, - float scale_factor -); + float scale_factor); /*************************************************************************************************** * FP8 recipe @@ -668,32 +342,23 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio std::vector scales, std::vector scale_invs, const std::string &amax_compute_algo, - transformer_engine::DType fp8_dtype, - float margin); + transformer_engine::DType fp8_dtype, float margin); /*************************************************************************************************** * Rotary positional embedding **************************************************************************************************/ -at::Tensor fused_rope_forward(const at::Tensor &input, - const at::Tensor &freqs, - const bool transpose_output_memory -); +at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, + const bool transpose_output_memory); -at::Tensor fused_rope_backward(const at::Tensor &output_grads, - const at::Tensor &freqs, - const bool transpose_output_memory -); +at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const bool transpose_output_memory); -at::Tensor fused_rope_thd_forward(const at::Tensor &input, - const at::Tensor &cu_seqlens, - const at::Tensor &freqs -); +at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, + const at::Tensor &freqs); -at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, - const at::Tensor &cu_seqlens, - const at::Tensor &freqs -); +at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, + const at::Tensor &freqs); /*************************************************************************************************** * Miscellaneous @@ -705,48 +370,29 @@ size_t get_cudnn_version(); void placeholder(); - /*************************************************************************************************** * Support THD format for Context Parallel **************************************************************************************************/ -at::Tensor thd_read_half_tensor(const at::Tensor &tensor, - const at::Tensor &cu_seqlens, - int half_idx -); - -void thd_second_half_lse_correction(at::Tensor lse, - const at::Tensor &lse_per_step, - const at::Tensor &cu_seqlens, - int total_tokens -); - -at::Tensor thd_read_second_half_lse(const at::Tensor &lse, - const 
at::Tensor &cu_seqlens, - int total_tokens -); - -void thd_out_correction(at::Tensor out, - const at::Tensor &out_per_step, - const at::Tensor &lse, - const at::Tensor &lse_per_step, - const at::Tensor &cu_seqlens, - bool only_second_half -); - -void thd_grad_correction(at::Tensor grad, - const at::Tensor &grad_per_step, - const at::Tensor &cu_seqlens, - const std::string &first_half, - const std::string &second_half -); - -at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, - int total_tokens, - int world_size, - int rank -); +at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens, + int half_idx); + +void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens, int total_tokens); + +at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, + int total_tokens); + +void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, + const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, + bool only_second_half); + +void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens, const std::string &first_half, + const std::string &second_half); +at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, + int world_size, int rank); /*************************************************************************************************** * multi_tensor_* kernels diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cu b/transformer_engine/pytorch/csrc/extensions/activation.cu index 8b3a3ba3383c35aff330e3d956abbee5644c5094..7f8cff5584a2fe4cc94d5d361cc384227eda299e 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cu +++ b/transformer_engine/pytorch/csrc/extensions/activation.cu @@ -6,51 +6,37 @@ #include "extensions.h" -at::Tensor gelu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { +at::Tensor gelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = input.numel() / N; - auto output = - allocateTorchTensor(M, - N, - otype); + auto output = allocateTorchTensor(M, N, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return output; } -at::Tensor dgelu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -) { +at::Tensor dgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = input.numel() / N; - auto output = - allocateTorchTensor(M, - N, - otype); + auto output = allocateTorchTensor(M, N, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = 
makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); nvte_dgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); @@ -58,51 +44,37 @@ at::Tensor dgelu(at::Tensor grad, return output; } -at::Tensor relu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { +at::Tensor relu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = static_cast(input.numel()) / N; - auto output = - allocateTorchTensor(M, - N, - otype); + auto output = allocateTorchTensor(M, N, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); nvte_relu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return output; } -at::Tensor drelu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -) { +at::Tensor drelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = input.numel() / N; - auto output = - allocateTorchTensor(M, - N, - otype); + auto output = allocateTorchTensor(M, N, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); nvte_drelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); @@ -110,51 +82,38 @@ at::Tensor drelu(at::Tensor grad, return output; } -at::Tensor geglu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { +at::Tensor geglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = input.numel() / N; - auto output = - allocateTorchTensor(M, - N / 2, - otype); + auto output = allocateTorchTensor(M, N / 2, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); + auto input_cu = 
makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto output_cu = + makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); nvte_geglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return output; } -at::Tensor dgeglu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -) { +at::Tensor dgeglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = input.numel() / N; - auto output = - allocateTorchTensor(M, - N, - otype); + auto output = allocateTorchTensor(M, N, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); nvte_dgeglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); @@ -162,51 +121,38 @@ at::Tensor dgeglu(at::Tensor grad, return output; } -at::Tensor reglu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { +at::Tensor reglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = input.numel() / N; - auto output = - allocateTorchTensor(M, - N / 2, - otype); + auto output = allocateTorchTensor(M, N / 2, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto output_cu = + makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); nvte_reglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return output; } -at::Tensor dreglu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -) { +at::Tensor dreglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = input.numel() / N; - auto output = - allocateTorchTensor(M, - N, - otype); + auto output = allocateTorchTensor(M, N, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); nvte_dreglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); @@ -214,51 +160,38 @@ at::Tensor 
dreglu(at::Tensor grad, return output; } -at::Tensor swiglu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { +at::Tensor swiglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = input.numel() / N; - auto output = - allocateTorchTensor(M, - N / 2, - otype); + auto output = allocateTorchTensor(M, N / 2, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto output_cu = + makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); nvte_swiglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return output; } -at::Tensor dswiglu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -) { +at::Tensor dswiglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = input.numel() / N; - auto output = - allocateTorchTensor(M, - N, - otype); + auto output = allocateTorchTensor(M, N, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); @@ -266,46 +199,32 @@ at::Tensor dswiglu(at::Tensor grad, return output; } -at::Tensor qgelu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { +at::Tensor qgelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = input.numel() / N; - auto output = - allocateTorchTensor(M, - N, - otype); + auto output = allocateTorchTensor(M, N, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); nvte_qgelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return output; } -at::Tensor dqgelu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -) { +at::Tensor dqgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = input.numel() / N; - auto output = - 
allocateTorchTensor(M, - N, - otype); + auto output = allocateTorchTensor(M, N, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); auto gtype = GetTransformerEngineDType(grad.scalar_type()); @@ -318,55 +237,40 @@ at::Tensor dqgelu(at::Tensor grad, return output; } -at::Tensor srelu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { +at::Tensor srelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = static_cast(input.numel()) / N; - auto output = - allocateTorchTensor(M, - N, - otype); + auto output = allocateTorchTensor(M, N, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); nvte_srelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return output; } -at::Tensor dsrelu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -) { +at::Tensor dsrelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { using namespace transformer_engine; size_t N = static_cast(input.size(-1)); size_t M = input.numel() / N; - auto output = - allocateTorchTensor(M, - N, - otype); + auto output = allocateTorchTensor(M, N, otype); auto itype = GetTransformerEngineDType(input.scalar_type()); auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); nvte_dsrelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return output; } - diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu index 455d152fe81851e9e429f70a3c89e7a39d7ba37d..c58ba91d5e05466a8042db7f3ef40881236484ff 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu @@ -57,23 +57,20 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, auto freqs_cu = makeTransformerEngineTensor(freqs); auto output_cu = makeTransformerEngineTensor(output); - nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, - b, h, d, d2, stride_s, stride_b, stride_h, stride_d, - o_stride_s, o_stride_b, o_stride_h, o_stride_d, - at::cuda::getCurrentCUDAStream()); + nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2, + stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); return output; } -at::Tensor fused_rope_backward(const at::Tensor &output_grads, - const at::Tensor &freqs, +at::Tensor 
fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, const bool transpose_output_memory) { using namespace transformer_engine; TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK( - output_grads.size(0) <= freqs.size(0), - "expected freqs tensor has a longer sequence length than output_grads"); + TORCH_CHECK(output_grads.size(0) <= freqs.size(0), + "expected freqs tensor has a longer sequence length than output_grads"); TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, "expected the second and third dims of the freqs tensor equal 1"); TORCH_CHECK(output_grads.size(3) >= freqs.size(3), @@ -116,16 +113,14 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, auto freqs_cu = makeTransformerEngineTensor(freqs); auto input_grads_cu = makeTransformerEngineTensor(input_grads); - nvte_fused_rope_backward( - output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h, - d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); + nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h, + d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); return input_grads; } -at::Tensor fused_rope_thd_forward(const at::Tensor &input, - const at::Tensor &cu_seqlens, +at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, const at::Tensor &freqs) { using namespace transformer_engine; TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); @@ -169,16 +164,14 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, auto freqs_cu = makeTransformerEngineTensor(freqs); auto output_cu = makeTransformerEngineTensor(output); - nvte_fused_rope_thd_forward( - input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(), - max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, - o_stride_d, at::cuda::getCurrentCUDAStream()); + nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), + output_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, stride_d, + o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); return output; } -at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, - const at::Tensor &cu_seqlens, +at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, const at::Tensor &freqs) { using namespace transformer_engine; TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); @@ -220,10 +213,10 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, auto freqs_cu = makeTransformerEngineTensor(freqs); auto input_grads_cu = makeTransformerEngineTensor(input_grads); - nvte_fused_rope_thd_backward( - output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, stride_d, - o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); + nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), + input_grads_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, + stride_d, o_stride_t, o_stride_h, o_stride_d, + at::cuda::getCurrentCUDAStream()); return input_grads; } diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 
6ef10e6b67fce9622b9a9a70e7439275563068a5..84b071b7e39f0e793e815da2874a5d82c0b77dab 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -10,35 +10,29 @@ constexpr int block_size = 512; constexpr int ctas_per_sm = 4; // get the fused attention backend -NVTE_Fused_Attn_Backend get_fused_attn_backend( - const transformer_engine::DType q_dtype, - const transformer_engine::DType kv_dtype, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - float p_dropout, - size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim) { +NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype, + const transformer_engine::DType kv_dtype, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, float p_dropout, + size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim) { NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), - qkv_layout, bias_type, attn_mask_type, p_dropout, - num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim); + nvte_get_fused_attn_backend(static_cast(q_dtype), static_cast(kv_dtype), + qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads, + num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim); return fused_attention_backend; } // fast zero-fills of tensors template -__global__ void __launch_bounds__(block_size) mha_fill_kernel(scalar_t* out_tensor, - const int32_t* const start_row, - const size_t num_rows) { +__global__ void __launch_bounds__(block_size) + mha_fill_kernel(scalar_t *out_tensor, const int32_t *const start_row, const size_t num_rows) { size_t row_stride = gridDim.y * blockDim.x; size_t row_index = blockIdx.x + static_cast(start_row[0]); size_t col_index = blockIdx.y * blockDim.x + threadIdx.x; while (row_index < num_rows) { - out_tensor[row_index*row_stride + col_index] = 0; + out_tensor[row_index * row_stride + col_index] = 0; row_index += gridDim.x; } } @@ -56,22 +50,20 @@ void mha_fill(const at::Tensor &self, const at::Tensor &start_index) { dim3 dim_grid(num_blk_x, num_blk_y); dim3 dim_block(block_size); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( - at::ScalarType::Half, at::ScalarType::BFloat16, - self_2d.scalar_type(), "mha_fill", [&]() { - mha_fill_kernel<<>>( - self_2d.data_ptr(), - static_cast(start_index.data_ptr()), - max_tokens); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); + at::ScalarType::Half, at::ScalarType::BFloat16, self_2d.scalar_type(), "mha_fill", [&]() { + mha_fill_kernel<<>>( + self_2d.data_ptr(), static_cast(start_index.data_ptr()), + max_tokens); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); } // extract seed and offset from PhiloxCudaState -__global__ void unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { +__global__ void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) { if (arg.captured_) { rng_state_ptr[0] = static_cast(*arg.seed_.ptr); - rng_state_ptr[1] = static_cast( - *(arg.offset_.ptr) + static_cast(arg.offset_intragraph_)); + rng_state_ptr[1] = + static_cast(*(arg.offset_.ptr) + static_cast(arg.offset_intragraph_)); } else { rng_state_ptr[0] = static_cast(arg.seed_.val); rng_state_ptr[1] = static_cast(arg.offset_.val); @@ -79,9 +71,7 @@ __global__ void unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { } // extract PhiloxCudaState from CUDA random 
number generator -at::PhiloxCudaState init_philox_state( - at::CUDAGeneratorImpl* gen, - size_t elts_per_thread) { +at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_per_thread) { at::PhiloxCudaState philox_args; std::lock_guard lock(gen->mutex_); philox_args = gen->philox_cuda_state(elts_per_thread); @@ -90,25 +80,16 @@ at::PhiloxCudaState init_philox_state( // fused attention FWD with packed QKV std::vector fused_attn_fwd_qkvpacked( - size_t max_seqlen, bool is_training, float attn_scale, - float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens, - const at::Tensor QKV, - const transformer_engine::DType qkv_type, - const c10::optional seq_offsets_q, - const c10::optional seq_offsets_k, - const c10::optional seq_offsets_v, - const c10::optional seq_offsets_o, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional scale_S, - const c10::optional scale_O, - c10::optional amax_S, - c10::optional amax_O, - const c10::optional Bias, - const c10::optional rng_gen, - size_t rng_elts_per_thread) { + size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, + const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional seq_offsets_o, + const c10::optional descale_QKV, const c10::optional descale_S, + const c10::optional scale_S, const c10::optional scale_O, + c10::optional amax_S, c10::optional amax_O, + const c10::optional Bias, const c10::optional rng_gen, + size_t rng_elts_per_thread) { using namespace transformer_engine; auto qkv_sizes = QKV.sizes().vec(); @@ -132,81 +113,74 @@ std::vector fused_attn_fwd_qkvpacked( // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - if (set_zero - && ((h * d) % block_size == 0) - && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && ((h * d) % block_size == 0) && + (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { O.fill_(0); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) - || (!scale_S.has_value()) || (!scale_O.has_value()) - || (!amax_S.has_value()) || (!amax_O.has_value())) { + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || + (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, - qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, + descale_QKV.value().data_ptr()); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), + scale_S.value().data_ptr(), descale_S.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), + scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { // BF16 or FP16 - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, - qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); + te_QKV = + makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); } if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { auto bias_sizes = Bias.value().sizes().vec(); std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, - DType::kFloat32, nullptr, nullptr, nullptr); + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32, + nullptr, nullptr, nullptr); } auto cu_seqlens_sizes = cu_seqlens.sizes().vec(); std::vector cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()}; te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape, - DType::kInt32, nullptr, nullptr, nullptr); - - if ((seq_offsets_q.has_value()) - && (seq_offsets_k.has_value()) - && (seq_offsets_v.has_value()) - && (seq_offsets_o.has_value())) { - auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); - std::vector seq_offsets_q_shape{ - seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; - auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); - std::vector seq_offsets_k_shape{ - seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; - auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); - std::vector seq_offsets_v_shape{ - seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; - auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); - std::vector seq_offsets_o_shape{ - seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; - te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), - seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), - seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), - seq_offsets_v_shape, DType::kInt32, nullptr, 
nullptr, nullptr); - te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), - seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr); + DType::kInt32, nullptr, nullptr, nullptr); + + if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) && + (seq_offsets_o.has_value())) { + auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); + std::vector seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; + auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); + std::vector seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; + auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); + std::vector seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); + std::vector seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; + te_seq_offsets_q = + makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_k = + makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_v = + makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_o = + makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape, + DType::kInt32, nullptr, nullptr, nullptr); } // extract random number generator seed and offset auto gen = at::get_generator_or_default( - rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); + rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( - philox_args, static_cast(rng_state.data_ptr())); + philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); // create auxiliary output tensors @@ -218,51 +192,40 @@ std::vector fused_attn_fwd_qkvpacked( // populate tensors with appropriate shapes and dtypes nvte_fused_attn_fwd_qkvpacked( - te_QKV.data(), - te_Bias.data(), - te_S.data(), - te_O.data(), - &nvte_aux_tensor_pack, - te_cu_seqlens.data(), - te_seq_offsets_q.data(), - te_seq_offsets_k.data(), - te_seq_offsets_v.data(), - te_seq_offsets_o.data(), - te_rng_state.data(), - max_seqlen, - is_training, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); + te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, + te_cu_seqlens.data(), te_seq_offsets_q.data(), te_seq_offsets_k.data(), + te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_rng_state.data(), max_seqlen, + is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(), + at::cuda::getCurrentCUDAStream()); // allocate memory for workspace and auxiliary output tensors auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor( - workspace_data.data_ptr(), - workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // output_tensors = [O, nvte_aux_tensor_pack.tensors] std::vector output_tensors; 
output_tensors.push_back(O); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); // allocate memory for nvte_aux_tensor_pack.tensors at::Tensor output_tensor; if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - output_tensor = (i < nvte_aux_tensor_pack.size-1) - ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { + if (i < nvte_aux_tensor_pack.size - 2) { + output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + } else if (i == nvte_aux_tensor_pack.size - 2) { + output_tensor = rng_state; + } else if (i == nvte_aux_tensor_pack.size - 1) { + output_tensor = Bias.value(); } + } else { + output_tensor = (i < nvte_aux_tensor_pack.size - 1) + ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) + : rng_state; + } } else { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); } output_tensors.push_back(output_tensor); tensor->data.dptr = output_tensor.data_ptr(); @@ -270,22 +233,11 @@ std::vector fused_attn_fwd_qkvpacked( // execute the kernel nvte_fused_attn_fwd_qkvpacked( - te_QKV.data(), - te_Bias.data(), - te_S.data(), - te_O.data(), - &nvte_aux_tensor_pack, - te_cu_seqlens.data(), - te_seq_offsets_q.data(), - te_seq_offsets_k.data(), - te_seq_offsets_v.data(), - te_seq_offsets_o.data(), - te_rng_state.data(), - max_seqlen, - is_training, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); + te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, + te_cu_seqlens.data(), te_seq_offsets_q.data(), te_seq_offsets_k.data(), + te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_rng_state.data(), max_seqlen, + is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(), + at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -296,29 +248,18 @@ std::vector fused_attn_fwd_qkvpacked( // fused attention BWD with packed QKV std::vector fused_attn_bwd_qkvpacked( - size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens, - const at::Tensor QKV, - const at::Tensor O, - const at::Tensor dO, - const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, - const std::vector Aux_CTX_Tensors, - const c10::optional seq_offsets_q, - const c10::optional seq_offsets_k, - const c10::optional seq_offsets_v, - const c10::optional seq_offsets_o, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional descale_O, - const c10::optional descale_dO, - const c10::optional descale_dP, - const c10::optional scale_S, - const c10::optional scale_dP, - const c10::optional 
scale_dQKV, - c10::optional amax_dP, - c10::optional amax_dQKV) { + size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens, + const at::Tensor QKV, const at::Tensor O, const at::Tensor dO, + const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, + const std::vector Aux_CTX_Tensors, const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional descale_O, + const c10::optional descale_dO, const c10::optional descale_dP, + const c10::optional scale_S, const c10::optional scale_dP, + const c10::optional scale_dQKV, c10::optional amax_dP, + c10::optional amax_dQKV) { using namespace transformer_engine; auto qkv_sizes = QKV.sizes().vec(); @@ -340,50 +281,45 @@ std::vector fused_attn_bwd_qkvpacked( if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 auto d = q_shape[q_shape.size() - 1]; - if (set_zero - && ((h * d) % block_size == 0) - && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && ((h * d) % block_size == 0) && + (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { dQKV.fill_(0); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) - || (!descale_O.has_value()) || (!descale_dO.has_value()) - || (!descale_dP.has_value()) || (!scale_S.has_value()) - || (!scale_dP.has_value()) || (!scale_dQKV.has_value()) - || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!descale_O.has_value()) || + (!descale_dO.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value()) || + (!scale_dP.has_value()) || (!scale_dQKV.has_value()) || (!amax_dP.has_value()) || + (!amax_dQKV.has_value())) { std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, - dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, amax_dP.value().data_ptr(), - scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, + descale_QKV.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, + descale_O.value().data_ptr()); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, + descale_dO.value().data_ptr()); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, + scale_S.value().data_ptr(), descale_S.value().data_ptr()); + te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), + scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, dqkv_type, - amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + amax_dQKV.value().data_ptr(), + scale_dQKV.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { // BF16 or FP16 - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, - qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, - dqkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, - dqkv_type, nullptr, nullptr, nullptr); + te_QKV = + makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); + te_dO = + makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); + te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); + te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, dqkv_type, nullptr, nullptr, + nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); } @@ -393,7 +329,7 @@ std::vector fused_attn_bwd_qkvpacked( nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); tensor->data.shape = std::vector(tmp.begin(), tmp.end()); @@ -403,16 +339,15 @@ std::vector fused_attn_bwd_qkvpacked( // create dBias the same shape as Bias at::Tensor dBias; TensorWrapper te_dBias; - if ((bias_type != NVTE_NO_BIAS) - && (bias_type != NVTE_ALIBI)) { + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { if (nvte_aux_tensor_pack.size >= 2) { std::vector bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec()); dBias = torch::empty(bias_shape, options); te_dBias = makeTransformerEngineTensor(dBias); } else { - dBias = torch::empty({1, static_cast(h), - static_cast(max_seqlen), - static_cast(max_seqlen)}, options); + dBias = torch::empty({1, static_cast(h), static_cast(max_seqlen), + static_cast(max_seqlen)}, + options); te_dBias = makeTransformerEngineTensor(dBias); } } @@ -420,34 +355,32 @@ std::vector fused_attn_bwd_qkvpacked( // create cu_seqlens tensorwrappers auto cu_seqlens_sizes = cu_seqlens.sizes().vec(); std::vector cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()}; - TensorWrapper te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape, - DType::kInt32, nullptr, nullptr, nullptr); + TensorWrapper te_cu_seqlens = makeTransformerEngineTensor( + cu_seqlens.data_ptr(), cu_seqlens_shape, DType::kInt32, nullptr, nullptr, nullptr); TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; - if ((seq_offsets_q.has_value()) - && (seq_offsets_k.has_value()) - && (seq_offsets_v.has_value()) - && (seq_offsets_o.has_value())) { - auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); - std::vector seq_offsets_q_shape{ - seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; - auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); - std::vector seq_offsets_k_shape{ - seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; - auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); - std::vector seq_offsets_v_shape{ - seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; - auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); - std::vector seq_offsets_o_shape{ - seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; - te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), - seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), - seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), - seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), - seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr); + if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) && + (seq_offsets_o.has_value())) { + auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); + std::vector seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; + auto 
seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); + std::vector seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; + auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); + std::vector seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); + std::vector seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; + te_seq_offsets_q = + makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_k = + makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_v = + makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_o = + makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape, + DType::kInt32, nullptr, nullptr, nullptr); } // create workspace @@ -455,51 +388,24 @@ std::vector fused_attn_bwd_qkvpacked( // populate tensors with appropriate shapes and dtypes nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), - te_O.data(), - te_dO.data(), - te_S.data(), - te_dP.data(), - &nvte_aux_tensor_pack, - te_dQKV.data(), - te_dBias.data(), - te_cu_seqlens.data(), - te_seq_offsets_q.data(), - te_seq_offsets_k.data(), - te_seq_offsets_v.data(), - te_seq_offsets_o.data(), - max_seqlen, - attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); + te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, + te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_seq_offsets_q.data(), + te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_seq_offsets_o.data(), max_seqlen, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(), + at::cuda::getCurrentCUDAStream()); // allocate memory for workspace auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor( - workspace_data.data_ptr(), - workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // execute kernel nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), - te_O.data(), - te_dO.data(), - te_S.data(), - te_dP.data(), - &nvte_aux_tensor_pack, - te_dQKV.data(), - te_dBias.data(), - te_cu_seqlens.data(), - te_seq_offsets_q.data(), - te_seq_offsets_k.data(), - te_seq_offsets_v.data(), - te_seq_offsets_o.data(), - max_seqlen, - attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); + te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, + te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_seq_offsets_q.data(), + te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_seq_offsets_o.data(), max_seqlen, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(), + at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -509,27 +415,17 @@ std::vector fused_attn_bwd_qkvpacked( // fused attention FWD with packed KV std::vector fused_attn_fwd_kvpacked( - size_t max_seqlen_q, size_t max_seqlen_kv, - bool is_training, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout 
qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, - const at::Tensor Q, - const at::Tensor KV, - const transformer_engine::DType qkv_type, - const c10::optional seq_offsets_q, - const c10::optional seq_offsets_k, - const c10::optional seq_offsets_v, - const c10::optional seq_offsets_o, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional scale_S, - const c10::optional scale_O, - c10::optional amax_S, - c10::optional amax_O, - const c10::optional Bias, - const c10::optional rng_gen, - size_t rng_elts_per_thread) { + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, + const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, + const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional seq_offsets_o, + const c10::optional descale_QKV, const c10::optional descale_S, + const c10::optional scale_S, const c10::optional scale_O, + c10::optional amax_S, c10::optional amax_O, + const c10::optional Bias, const c10::optional rng_gen, + size_t rng_elts_per_thread) { using namespace transformer_engine; auto q_sizes = Q.sizes().vec(); @@ -549,89 +445,81 @@ std::vector fused_attn_fwd_kvpacked( // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - if (set_zero - && ((h * d) % block_size == 0) - && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && ((h * d) % block_size == 0) && + (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { O.fill_(0); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) - || (!scale_S.has_value()) || (!scale_O.has_value()) - || (!amax_S.has_value()) || (!amax_O.has_value())) { + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || + (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, - qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); + te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, + descale_QKV.value().data_ptr()); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, + descale_QKV.value().data_ptr()); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), + scale_S.value().data_ptr(), descale_S.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), + scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, - qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); + te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); + te_KV = + makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); } if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { auto bias_sizes = Bias.value().sizes().vec(); std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, - DType::kFloat32, nullptr, nullptr, nullptr); + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32, + nullptr, nullptr, nullptr); } auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); + DType::kInt32, nullptr, nullptr, nullptr); te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); - - if ((seq_offsets_q.has_value()) - && (seq_offsets_k.has_value()) - && (seq_offsets_v.has_value()) - && (seq_offsets_o.has_value())) { - auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); - std::vector seq_offsets_q_shape{ - seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; - auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); - std::vector seq_offsets_k_shape{ - seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; - auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); - std::vector seq_offsets_v_shape{ - seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; - auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); - std::vector seq_offsets_o_shape{ - seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; - te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), - seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), - seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), - seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), - seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr); + DType::kInt32, nullptr, nullptr, nullptr); + + if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) && + (seq_offsets_o.has_value())) { + auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); + std::vector seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; + auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); + std::vector seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; + auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); + std::vector seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); + std::vector seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; + te_seq_offsets_q = + makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_k = + makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape, + DType::kInt32, nullptr, nullptr, nullptr); + 
te_seq_offsets_v = + makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_o = + makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape, + DType::kInt32, nullptr, nullptr, nullptr); } // extract rng seed and offset auto gen = at::get_generator_or_default( - rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); + rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( - philox_args, static_cast(rng_state.data_ptr())); + philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); // create auxiliary output tensors @@ -643,53 +531,40 @@ std::vector fused_attn_fwd_kvpacked( // populate tensors with appropriate shapes and dtypes nvte_fused_attn_fwd_kvpacked( - te_Q.data(), - te_KV.data(), - te_Bias.data(), - te_S.data(), - te_O.data(), - &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), - te_seq_offsets_q.data(), - te_seq_offsets_k.data(), - te_seq_offsets_v.data(), - te_seq_offsets_o.data(), - te_rng_state.data(), - max_seqlen_q, max_seqlen_kv, - is_training, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); + te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, + te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_seq_offsets_q.data(), + te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_seq_offsets_o.data(), + te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace and auxiliary output tensors auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor( - workspace_data.data_ptr(), - workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // output_tensors = [O, nvte_aux_tensor_pack.tensors] std::vector output_tensors; output_tensors.push_back(O); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); // allocate memory for nvte_aux_tensor_pack.tensors at::Tensor output_tensor; if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - output_tensor = (i < nvte_aux_tensor_pack.size-1) - ? 
allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { + if (i < nvte_aux_tensor_pack.size - 2) { + output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + } else if (i == nvte_aux_tensor_pack.size - 2) { + output_tensor = rng_state; + } else if (i == nvte_aux_tensor_pack.size - 1) { + output_tensor = Bias.value(); } + } else { + output_tensor = (i < nvte_aux_tensor_pack.size - 1) + ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) + : rng_state; + } } else { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); } output_tensors.push_back(output_tensor); tensor->data.dptr = output_tensor.data_ptr(); @@ -697,24 +572,11 @@ std::vector fused_attn_fwd_kvpacked( // execute the kernel nvte_fused_attn_fwd_kvpacked( - te_Q.data(), - te_KV.data(), - te_Bias.data(), - te_S.data(), - te_O.data(), - &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), - te_seq_offsets_q.data(), - te_seq_offsets_k.data(), - te_seq_offsets_v.data(), - te_seq_offsets_o.data(), - te_rng_state.data(), - max_seqlen_q, max_seqlen_kv, - is_training, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); + te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, + te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_seq_offsets_q.data(), + te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_seq_offsets_o.data(), + te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -725,32 +587,19 @@ std::vector fused_attn_fwd_kvpacked( // fused attention BWD with packed KV std::vector fused_attn_bwd_kvpacked( - size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, - const at::Tensor Q, - const at::Tensor KV, - const at::Tensor O, - const at::Tensor dO, - const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, - const std::vector Aux_CTX_Tensors, - const c10::optional seq_offsets_q, - const c10::optional seq_offsets_k, - const c10::optional seq_offsets_v, - const c10::optional seq_offsets_o, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional descale_O, - const c10::optional descale_dO, - const c10::optional descale_dP, - const c10::optional scale_S, - const c10::optional scale_dP, - const c10::optional scale_dQKV, - c10::optional amax_dP, - c10::optional amax_dQKV) { + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, + const at::Tensor KV, const at::Tensor O, const at::Tensor dO, + const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, + const std::vector Aux_CTX_Tensors, const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, 
const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional descale_O, + const c10::optional descale_dO, const c10::optional descale_dP, + const c10::optional scale_S, const c10::optional scale_dP, + const c10::optional scale_dQKV, c10::optional amax_dP, + c10::optional amax_dQKV) { using namespace transformer_engine; auto q_sizes = Q.sizes().vec(); @@ -776,61 +625,55 @@ std::vector fused_attn_bwd_kvpacked( TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero - && ((h_q * d)% block_size == 0) - && ((h_kv * d)% block_size == 0) - && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && ((h_q * d) % block_size == 0) && ((h_kv * d) % block_size == 0) && + (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { dQ.fill_(0); dKV.fill_(0); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) - || (!descale_O.has_value()) || (!descale_dO.has_value()) - || (!descale_dP.has_value()) || (!scale_S.has_value()) - || (!scale_dP.has_value()) || (!scale_dQKV.has_value()) - || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!descale_O.has_value()) || + (!descale_dO.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value()) || + (!scale_dP.has_value()) || (!scale_dQKV.has_value()) || (!amax_dP.has_value()) || + (!amax_dQKV.has_value())) { std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, - dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); + te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, + descale_QKV.value().data_ptr()); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, + descale_QKV.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, + descale_O.value().data_ptr()); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, + descale_dO.value().data_ptr()); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), - descale_dP.value().data_ptr()); - te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, - amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + scale_S.value().data_ptr(), descale_S.value().data_ptr()); + te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), + scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); + te_dQ = + makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, amax_dQKV.value().data_ptr(), + scale_dQKV.value().data_ptr(), nullptr); te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, dqkv_type, - amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + amax_dQKV.value().data_ptr(), + scale_dQKV.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, - qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, - dqkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, - dqkv_type, nullptr, nullptr, nullptr); - te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, - dqkv_type, nullptr, nullptr, nullptr); + te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); + te_KV = + makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); + te_dO = + makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); + te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); + te_dQ = + 
makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); + te_dKV = + makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, dqkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); } @@ -842,35 +685,33 @@ std::vector fused_attn_bwd_kvpacked( std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); + DType::kInt32, nullptr, nullptr, nullptr); te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); + DType::kInt32, nullptr, nullptr, nullptr); TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; - if ((seq_offsets_q.has_value()) - && (seq_offsets_k.has_value()) - && (seq_offsets_v.has_value()) - && (seq_offsets_o.has_value())) { - auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); - std::vector seq_offsets_q_shape{ - seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; - auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); - std::vector seq_offsets_k_shape{ - seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; - auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); - std::vector seq_offsets_v_shape{ - seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; - auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); - std::vector seq_offsets_o_shape{ - seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; - te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), - seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), - seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), - seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), - seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr); + if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) && + (seq_offsets_o.has_value())) { + auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); + std::vector seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; + auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); + std::vector seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; + auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); + std::vector seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); + std::vector seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; + te_seq_offsets_q = + makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_k = + makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_v = + makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_o = + 
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape, + DType::kInt32, nullptr, nullptr, nullptr); } // convert auxiliary tensors from forward to NVTETensors @@ -878,7 +719,7 @@ std::vector fused_attn_bwd_kvpacked( nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); tensor->data.shape = std::vector(tmp.begin(), tmp.end()); @@ -888,16 +729,15 @@ std::vector fused_attn_bwd_kvpacked( // create dBias the same shape as Bias at::Tensor dBias; TensorWrapper te_dBias; - if ((bias_type != NVTE_NO_BIAS) - && (bias_type != NVTE_ALIBI)) { + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { if (nvte_aux_tensor_pack.size >= 2) { std::vector bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec()); dBias = torch::empty(bias_shape, options); te_dBias = makeTransformerEngineTensor(dBias); } else { - dBias = torch::empty({1, static_cast(h_q), - static_cast(max_seqlen_q), - static_cast(max_seqlen_kv)}, options); + dBias = torch::empty({1, static_cast(h_q), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv)}, + options); te_dBias = makeTransformerEngineTensor(dBias); } } @@ -906,58 +746,27 @@ std::vector fused_attn_bwd_kvpacked( TensorWrapper workspace; // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_kvpacked( - te_Q.data(), - te_KV.data(), - te_O.data(), - te_dO.data(), - te_S.data(), - te_dP.data(), - &nvte_aux_tensor_pack, - te_dQ.data(), - te_dKV.data(), - te_dBias.data(), - te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), - te_seq_offsets_q.data(), - te_seq_offsets_k.data(), - te_seq_offsets_v.data(), - te_seq_offsets_o.data(), - max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), + te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), + te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_seq_offsets_q.data(), te_seq_offsets_k.data(), + te_seq_offsets_v.data(), te_seq_offsets_o.data(), max_seqlen_q, + max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, + attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor( - workspace_data.data_ptr(), - workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // execute kernel - nvte_fused_attn_bwd_kvpacked( - te_Q.data(), - te_KV.data(), - te_O.data(), - te_dO.data(), - te_S.data(), - te_dP.data(), - &nvte_aux_tensor_pack, - te_dQ.data(), - te_dKV.data(), - te_dBias.data(), - te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), - te_seq_offsets_q.data(), - te_seq_offsets_k.data(), - te_seq_offsets_v.data(), - te_seq_offsets_o.data(), - max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), 
te_dO.data(), te_S.data(), + te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), + te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_seq_offsets_q.data(), te_seq_offsets_k.data(), + te_seq_offsets_v.data(), te_seq_offsets_o.data(), max_seqlen_q, + max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, + attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -967,28 +776,17 @@ std::vector fused_attn_bwd_kvpacked( // fused attention FWD with separate Q, K and V tensors std::vector fused_attn_fwd( - size_t max_seqlen_q, size_t max_seqlen_kv, - bool is_training, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, - const at::Tensor Q, - const at::Tensor K, - const at::Tensor V, - const transformer_engine::DType qkv_type, - const c10::optional seq_offsets_q, - const c10::optional seq_offsets_k, - const c10::optional seq_offsets_v, - const c10::optional seq_offsets_o, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional scale_S, - const c10::optional scale_O, - c10::optional amax_S, - c10::optional amax_O, - const c10::optional Bias, - const c10::optional rng_gen, - size_t rng_elts_per_thread) { + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, + const at::Tensor Q, const at::Tensor K, const at::Tensor V, + const transformer_engine::DType qkv_type, const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional scale_S, + const c10::optional scale_O, c10::optional amax_S, + c10::optional amax_O, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto q_sizes = Q.sizes().vec(); @@ -1009,94 +807,84 @@ std::vector fused_attn_fwd( // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - if (set_zero - && ((h * d) % block_size == 0) - && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && ((h * d) % block_size == 0) && + (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { O.fill_(0); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) - || (!scale_S.has_value()) || (!scale_O.has_value()) - || (!amax_S.has_value()) || (!amax_O.has_value())) { + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || + (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, - qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); + te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, + descale_QKV.value().data_ptr()); + te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, + descale_QKV.value().data_ptr()); + te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, + descale_QKV.value().data_ptr()); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), + scale_S.value().data_ptr(), descale_S.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), + scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, - qkv_type, nullptr, nullptr, nullptr); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, - qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); + te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); + te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); + te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); } if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { auto bias_sizes = Bias.value().sizes().vec(); std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, - DType::kFloat32, nullptr, nullptr, nullptr); + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32, + nullptr, nullptr, nullptr); } auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); + DType::kInt32, nullptr, nullptr, nullptr); te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); - - if ((seq_offsets_q.has_value()) - && (seq_offsets_k.has_value()) - && (seq_offsets_v.has_value()) - && (seq_offsets_o.has_value())) { - auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); - std::vector seq_offsets_q_shape{ - seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; - auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); - std::vector seq_offsets_k_shape{ - seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; - auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); - std::vector seq_offsets_v_shape{ - seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; - auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); - std::vector seq_offsets_o_shape{ - seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; - te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), - seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), - seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), - seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), - seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr); + DType::kInt32, nullptr, nullptr, nullptr); + + if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) && + (seq_offsets_o.has_value())) { + auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); + std::vector seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; + auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); + std::vector seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; + auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); + std::vector seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); + std::vector seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; + te_seq_offsets_q = + makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_k = + makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape, + DType::kInt32, nullptr, nullptr, nullptr); + 
te_seq_offsets_v = + makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_o = + makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape, + DType::kInt32, nullptr, nullptr, nullptr); } // extract rng seed and offset auto gen = at::get_generator_or_default( - rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); + rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); auto rng_state = torch::empty({2}, options); unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( - philox_args, static_cast(rng_state.data_ptr())); + philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); // create auxiliary output tensors @@ -1107,81 +895,55 @@ std::vector fused_attn_fwd( TensorWrapper workspace; // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd( - te_Q.data(), - te_K.data(), - te_V.data(), - te_Bias.data(), - te_S.data(), - te_O.data(), - &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), - te_seq_offsets_q.data(), - te_seq_offsets_k.data(), - te_seq_offsets_v.data(), - te_seq_offsets_o.data(), - te_rng_state.data(), - max_seqlen_q, max_seqlen_kv, - is_training, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), te_seq_offsets_q.data(), te_seq_offsets_k.data(), + te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_rng_state.data(), + max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, + bias_type, attn_mask_type, workspace.data(), + at::cuda::getCurrentCUDAStream()); // allocate memory for workspace and auxiliary output tensors auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor( - workspace_data.data_ptr(), - workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // output_tensors = [O, nvte_aux_tensor_pack.tensors] std::vector output_tensors; output_tensors.push_back(O); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); // allocate memory for nvte_aux_tensor_pack.tensors at::Tensor output_tensor; if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - output_tensor = (i < nvte_aux_tensor_pack.size-1) - ? 
allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { + if (i < nvte_aux_tensor_pack.size - 2) { + output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + } else if (i == nvte_aux_tensor_pack.size - 2) { + output_tensor = rng_state; + } else if (i == nvte_aux_tensor_pack.size - 1) { + output_tensor = Bias.value(); } + } else { + output_tensor = (i < nvte_aux_tensor_pack.size - 1) + ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) + : rng_state; + } } else { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); } output_tensors.push_back(output_tensor); tensor->data.dptr = output_tensor.data_ptr(); } // execute the kernel - nvte_fused_attn_fwd( - te_Q.data(), - te_K.data(), - te_V.data(), - te_Bias.data(), - te_S.data(), - te_O.data(), - &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), - te_seq_offsets_q.data(), - te_seq_offsets_k.data(), - te_seq_offsets_v.data(), - te_seq_offsets_o.data(), - te_rng_state.data(), - max_seqlen_q, max_seqlen_kv, - is_training, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), te_seq_offsets_q.data(), te_seq_offsets_k.data(), + te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_rng_state.data(), + max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, + bias_type, attn_mask_type, workspace.data(), + at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -1192,33 +954,19 @@ std::vector fused_attn_fwd( // fused attention BWD with separate Q, K and V std::vector fused_attn_bwd( - size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, - const at::Tensor Q, - const at::Tensor K, - const at::Tensor V, - const at::Tensor O, - const at::Tensor dO, - const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, - const std::vector Aux_CTX_Tensors, - const c10::optional seq_offsets_q, - const c10::optional seq_offsets_k, - const c10::optional seq_offsets_v, - const c10::optional seq_offsets_o, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional descale_O, - const c10::optional descale_dO, - const c10::optional descale_dP, - const c10::optional scale_S, - const c10::optional scale_dP, - const c10::optional scale_dQKV, - c10::optional amax_dP, - c10::optional amax_dQKV) { + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, + const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO, + const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, + const std::vector Aux_CTX_Tensors, const c10::optional seq_offsets_q, + const c10::optional 
seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional descale_O, + const c10::optional descale_dO, const c10::optional descale_dP, + const c10::optional scale_S, const c10::optional scale_dP, + const c10::optional scale_dQKV, c10::optional amax_dP, + c10::optional amax_dQKV) { using namespace transformer_engine; auto q_sizes = Q.sizes().vec(); @@ -1239,73 +987,79 @@ std::vector fused_attn_bwd( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); std::vector tmp_shape; switch (layout_group) { - case NVTE_QKV_Layout_Group::NVTE_3HD: - tmp_shape = std::vector{q_sizes.begin(), q_sizes.end()}; - tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); - dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); - dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), - torch::indexing::Slice(0, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3); - dK = dQKV.index({"...", torch::indexing::Slice(1, 2, 1), - torch::indexing::Slice(0, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3); - dV = dQKV.index({"...", torch::indexing::Slice(2, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3); - break; - case NVTE_QKV_Layout_Group::NVTE_H3D: - tmp_shape = std::vector{q_sizes.begin(), q_sizes.end()}; - tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); - dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); - dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); - dK = dQKV.index({"...", torch::indexing::Slice(1, 2, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); - dV = dQKV.index({"...", torch::indexing::Slice(2, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); - break; - case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - dQ = torch::empty_like(Q, options); - tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; - tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); - dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); - dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), - torch::indexing::Slice(0, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3); - dV = dKV.index({"...", torch::indexing::Slice(1, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3); - break; - case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - dQ = torch::empty_like(Q, options); - tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; - tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); - dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); - dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); - dV = dKV.index({"...", torch::indexing::Slice(1, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); - break; - case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - dQ = 
torch::empty_like(Q, options); - dK = torch::empty_like(K, options); - dV = torch::empty_like(V, options); - break; - default: - NVTE_ERROR("QKV layout not supported!"); - } + case NVTE_QKV_Layout_Group::NVTE_3HD: + tmp_shape = std::vector{q_sizes.begin(), q_sizes.end()}; + tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); + dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); + dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 3); + dK = dQKV.index({"...", torch::indexing::Slice(1, 2, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 3); + dV = dQKV.index({"...", torch::indexing::Slice(2, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 3); + break; + case NVTE_QKV_Layout_Group::NVTE_H3D: + tmp_shape = std::vector{q_sizes.begin(), q_sizes.end()}; + tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); + dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); + dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 2); + dK = dQKV.index({"...", torch::indexing::Slice(1, 2, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 2); + dV = dQKV.index({"...", torch::indexing::Slice(2, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 2); + break; + case NVTE_QKV_Layout_Group::NVTE_HD_2HD: + dQ = torch::empty_like(Q, options); + tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; + tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); + dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); + dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 3); + dV = dKV.index({"...", torch::indexing::Slice(1, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 3); + break; + case NVTE_QKV_Layout_Group::NVTE_HD_H2D: + dQ = torch::empty_like(Q, options); + tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; + tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); + dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); + dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 2); + dV = dKV.index({"...", torch::indexing::Slice(1, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 2); + break; + case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + dQ = torch::empty_like(Q, options); + dK = torch::empty_like(K, options); + dV = torch::empty_like(V, options); + break; + default: + NVTE_ERROR("QKV layout not supported!"); + } // construct NVTE tensors TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero - && ((h_q * d) % block_size == 0) - && ((h_kv * d) % 
block_size == 0) - && dQ.is_contiguous() - && dK.is_contiguous() - && dV.is_contiguous() - && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && ((h_q * d) % block_size == 0) && ((h_kv * d) % block_size == 0) && + dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && + (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -1314,59 +1068,54 @@ std::vector fused_attn_bwd( dK.fill_(0); dV.fill_(0); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) - || (!descale_O.has_value()) || (!descale_dO.has_value()) - || (!descale_dP.has_value()) || (!scale_S.has_value()) - || (!scale_dP.has_value()) || (!scale_dQKV.has_value()) - || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!descale_O.has_value()) || + (!descale_dO.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value()) || + (!scale_dP.has_value()) || (!scale_dQKV.has_value()) || (!amax_dP.has_value()) || + (!amax_dQKV.has_value())) { std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, - dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); + te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, + descale_QKV.value().data_ptr()); + te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, + descale_QKV.value().data_ptr()); + te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, + descale_QKV.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, + descale_O.value().data_ptr()); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, + descale_dO.value().data_ptr()); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), - descale_dP.value().data_ptr()); - te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, - amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); - te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, dqkv_type, - amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); - te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, - amax_dQKV.value().data_ptr(), 
scale_dQKV.value().data_ptr(), nullptr); + scale_S.value().data_ptr(), descale_S.value().data_ptr()); + te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), + scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); + te_dQ = + makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, amax_dQKV.value().data_ptr(), + scale_dQKV.value().data_ptr(), nullptr); + te_dK = + makeTransformerEngineTensor(dK.data_ptr(), k_shape, dqkv_type, amax_dQKV.value().data_ptr(), + scale_dQKV.value().data_ptr(), nullptr); + te_dV = + makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, amax_dQKV.value().data_ptr(), + scale_dQKV.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, - qkv_type, nullptr, nullptr, nullptr); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, - qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, - dqkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, - dqkv_type, nullptr, nullptr, nullptr); - te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, - dqkv_type, nullptr, nullptr, nullptr); - te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, - dqkv_type, nullptr, nullptr, nullptr); + te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); + te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); + te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); + te_dO = + makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); + te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); + te_dQ = + makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); + te_dK = + makeTransformerEngineTensor(dK.data_ptr(), k_shape, dqkv_type, nullptr, nullptr, nullptr); + te_dV = + makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); } @@ -1378,35 +1127,33 @@ std::vector fused_attn_bwd( std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); + DType::kInt32, nullptr, nullptr, nullptr); te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); + DType::kInt32, nullptr, nullptr, nullptr); TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; - if ((seq_offsets_q.has_value()) - && (seq_offsets_k.has_value()) - && (seq_offsets_v.has_value()) - && (seq_offsets_o.has_value())) { - auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); - std::vector seq_offsets_q_shape{ - seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; - auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); - std::vector seq_offsets_k_shape{ - seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; - auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); - std::vector seq_offsets_v_shape{ - seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; - auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); - std::vector seq_offsets_o_shape{ - seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; - te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), - seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), - seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), - seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); - te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), - seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr); + if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) && + (seq_offsets_o.has_value())) { + auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); + std::vector seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; + auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); + std::vector seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; + auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); + std::vector seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); + std::vector seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; + te_seq_offsets_q = + makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_k = + makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_v = + makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_o = + makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape, + DType::kInt32, nullptr, nullptr, nullptr); } // convert auxiliary tensors from forward to NVTETensors @@ -1414,7 +1161,7 @@ std::vector fused_attn_bwd( nvte_tensor_pack_create(&nvte_aux_tensor_pack); 
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); tensor->data.shape = std::vector(tmp.begin(), tmp.end()); @@ -1424,16 +1171,15 @@ std::vector fused_attn_bwd( // create dBias the same shape as Bias at::Tensor dBias; TensorWrapper te_dBias; - if ((bias_type != NVTE_NO_BIAS) - && (bias_type != NVTE_ALIBI)) { + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { if (nvte_aux_tensor_pack.size >= 2) { std::vector bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec()); dBias = torch::empty(bias_shape, options); te_dBias = makeTransformerEngineTensor(dBias); } else { - dBias = torch::empty({1, static_cast(h_q), - static_cast(max_seqlen_q), - static_cast(max_seqlen_kv)}, options); + dBias = torch::empty({1, static_cast(h_q), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv)}, + options); te_dBias = makeTransformerEngineTensor(dBias); } } @@ -1442,62 +1188,27 @@ std::vector fused_attn_bwd( TensorWrapper workspace; // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd( - te_Q.data(), - te_K.data(), - te_V.data(), - te_O.data(), - te_dO.data(), - te_S.data(), - te_dP.data(), - &nvte_aux_tensor_pack, - te_dQ.data(), - te_dK.data(), - te_dV.data(), - te_dBias.data(), - te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), - te_seq_offsets_q.data(), - te_seq_offsets_k.data(), - te_seq_offsets_v.data(), - te_seq_offsets_o.data(), - max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), + te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), + te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, workspace.data(), + at::cuda::getCurrentCUDAStream()); // allocate memory for workspace auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor( - workspace_data.data_ptr(), - workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // execute kernel - nvte_fused_attn_bwd( - te_Q.data(), - te_K.data(), - te_V.data(), - te_O.data(), - te_dO.data(), - te_S.data(), - te_dP.data(), - &nvte_aux_tensor_pack, - te_dQ.data(), - te_dK.data(), - te_dV.data(), - te_dBias.data(), - te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), - te_seq_offsets_q.data(), - te_seq_offsets_k.data(), - te_seq_offsets_v.data(), - te_seq_offsets_o.data(), - max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), + te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), + te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_seq_offsets_q.data(), te_seq_offsets_k.data(), 
te_seq_offsets_v.data(), + te_seq_offsets_o.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, workspace.data(), + at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -1514,173 +1225,142 @@ constexpr int load_size = warp_size * nvec; constexpr int block_size = 512; template -__launch_bounds__(block_size) -__global__ void prepare_kernel_fwd(const T *qkvi, - T *qkv, - const size_t B, - const size_t S, - const size_t Z, - const size_t W) { - const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; - const int id_in_warp = threadIdx.x % warp_size; - const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec; - const T *my_input = qkvi + offset_input; - - const size_t s = warpid / B; - if (s >= S) return; - - const size_t b = warpid % B; - - const size_t offset_output = blockIdx.y * B * S * Z * W + - (s + b * S) * W * Z + - id_in_warp * nvec; - - T *my_output = qkv + offset_output; - - for (int i = 0; i < Z; ++i) { - uint64_t *out = reinterpret_cast(my_output + i * load_size); - *out = *reinterpret_cast(my_input + i * load_size * 3); - } +__launch_bounds__(block_size) __global__ + void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z, + const size_t W) { + const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; + const int id_in_warp = threadIdx.x % warp_size; + const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec; + const T *my_input = qkvi + offset_input; + + const size_t s = warpid / B; + if (s >= S) return; + + const size_t b = warpid % B; + + const size_t offset_output = blockIdx.y * B * S * Z * W + (s + b * S) * W * Z + id_in_warp * nvec; + + T *my_output = qkv + offset_output; + + for (int i = 0; i < Z; ++i) { + uint64_t *out = reinterpret_cast(my_output + i * load_size); + *out = *reinterpret_cast(my_input + i * load_size * 3); + } } template -__launch_bounds__(block_size) -__global__ void prepare_kernel_bwd(const T *q, const T *k, const T *v, - T *qkv, const size_t B, const size_t S, - const size_t Z, const size_t W) { - const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v); +__launch_bounds__(block_size) __global__ + void prepare_kernel_bwd(const T *q, const T *k, const T *v, T *qkv, const size_t B, + const size_t S, const size_t Z, const size_t W) { + const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? 
k : v); - const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; - const int id_in_warp = threadIdx.x % warp_size; - const size_t offset_input = warpid * W * Z + id_in_warp * nvec; - const T *my_input = input + offset_input; + const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; + const int id_in_warp = threadIdx.x % warp_size; + const size_t offset_input = warpid * W * Z + id_in_warp * nvec; + const T *my_input = input + offset_input; - const size_t b = warpid / S; - if (b >= B) return; + const size_t b = warpid / S; + if (b >= B) return; - const size_t s = warpid % S; + const size_t s = warpid % S; - const size_t offset_output = (b + s * B) * 3 * W * Z + - id_in_warp * nvec + blockIdx.y * W; + const size_t offset_output = (b + s * B) * 3 * W * Z + id_in_warp * nvec + blockIdx.y * W; - T *my_output = qkv + offset_output; + T *my_output = qkv + offset_output; - for (int i = 0; i < Z; ++i) { - uint64_t *out = reinterpret_cast(my_output + i * load_size * 3); - *out = *reinterpret_cast(my_input + i * load_size); - } + for (int i = 0; i < Z; ++i) { + uint64_t *out = reinterpret_cast(my_output + i * load_size * 3); + *out = *reinterpret_cast(my_input + i * load_size); + } } } // namespace flash_attention at::Tensor fa_prepare_fwd(at::Tensor qkvi) { - NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor."); - NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half || - qkvi.scalar_type() == at::ScalarType::BFloat16); - NVTE_CHECK(qkvi.size(3) % flash_attention::load_size == 0); - NVTE_CHECK(qkvi.size(3) == flash_attention::load_size); - NVTE_CHECK(qkvi.stride(3) == 1, "Wrong stride."); - NVTE_CHECK(qkvi.stride(2) == 3 * qkvi.size(3), "Wrong stride."); - NVTE_CHECK(qkvi.stride(1) == 3 * qkvi.size(3) * qkvi.size(2), "Wrong stride."); - NVTE_CHECK(qkvi.stride(0) == 3 * qkvi.size(3) * qkvi.size(2) * qkvi.size(1), "Wrong stride."); - - // [s, b, n, h * 3] -> [3, b, s, n, h] - std::vector shape = {3, qkvi.size(1), qkvi.size(0), qkvi.size(2), qkvi.size(3)}; - at::Tensor qkv = at::empty(shape, at::CUDA(qkvi.scalar_type())); - - size_t warps = qkvi.size(0) * qkvi.size(1); - size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; - size_t blocks = (warps + warps_per_block - 1) / warps_per_block; - dim3 grid(blocks, 3); - int threads = flash_attention::block_size; - if (qkvi.scalar_type() == at::ScalarType::Half) { - using dtype = at::Half; - flash_attention::prepare_kernel_fwd<<>>( - qkvi.data_ptr(), - qkv.data_ptr(), - shape[1], - shape[2], - shape[3], - shape[4]); - } else { - using dtype = at::BFloat16; - flash_attention::prepare_kernel_fwd<<>>( - qkvi.data_ptr(), - qkv.data_ptr(), - shape[1], - shape[2], - shape[3], - shape[4]); - } + NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half || + qkvi.scalar_type() == at::ScalarType::BFloat16); + NVTE_CHECK(qkvi.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(qkvi.size(3) == flash_attention::load_size); + NVTE_CHECK(qkvi.stride(3) == 1, "Wrong stride."); + NVTE_CHECK(qkvi.stride(2) == 3 * qkvi.size(3), "Wrong stride."); + NVTE_CHECK(qkvi.stride(1) == 3 * qkvi.size(3) * qkvi.size(2), "Wrong stride."); + NVTE_CHECK(qkvi.stride(0) == 3 * qkvi.size(3) * qkvi.size(2) * qkvi.size(1), "Wrong stride."); + + // [s, b, n, h * 3] -> [3, b, s, n, h] + std::vector shape = {3, qkvi.size(1), qkvi.size(0), qkvi.size(2), qkvi.size(3)}; + at::Tensor qkv = at::empty(shape, at::CUDA(qkvi.scalar_type())); + + size_t warps = qkvi.size(0) * qkvi.size(1); 
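// Launch geometry (a short worked example; the concrete sizes are assumed, not taken from the
// diff): one warp handles one (sequence, batch) position of the input, and blockIdx.y in
// {0, 1, 2} selects which of the three interleaved Q/K/V slices the block copies, so the grid
// built below is (ceil(warps / warps_per_block), 3). With block_size = 512 and 32-thread warps,
// warps_per_block = 16; for instance, s = 2048 and b = 2 would give warps = 4096 and blocks = 256.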
+ size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; + size_t blocks = (warps + warps_per_block - 1) / warps_per_block; + dim3 grid(blocks, 3); + int threads = flash_attention::block_size; + if (qkvi.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + flash_attention::prepare_kernel_fwd + <<>>( + qkvi.data_ptr(), qkv.data_ptr(), shape[1], shape[2], shape[3], shape[4]); + } else { + using dtype = at::BFloat16; + flash_attention::prepare_kernel_fwd + <<>>( + qkvi.data_ptr(), qkv.data_ptr(), shape[1], shape[2], shape[3], shape[4]); + } - return qkv; + return qkv; } at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { - NVTE_CHECK(q.is_contiguous()); - NVTE_CHECK(k.is_contiguous()); - NVTE_CHECK(v.is_contiguous()); - NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor."); - NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor."); - NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor."); - NVTE_CHECK(q.scalar_type() == at::ScalarType::Half || - q.scalar_type() == at::ScalarType::BFloat16); - NVTE_CHECK(k.scalar_type() == q.scalar_type()); - NVTE_CHECK(v.scalar_type() == q.scalar_type()); - NVTE_CHECK(q.size(3) % flash_attention::load_size == 0); - NVTE_CHECK(q.size(3) == flash_attention::load_size); - NVTE_CHECK(k.size(3) % flash_attention::load_size == 0); - NVTE_CHECK(k.size(3) == flash_attention::load_size); - NVTE_CHECK(v.size(3) % flash_attention::load_size == 0); - NVTE_CHECK(v.size(3) == flash_attention::load_size); - - // 3 x [s, b, n, h] -> [b, s, n, 3 * h] - - std::vector shape = {q.size(1), q.size(0), q.size(2), 3 * q.size(3)}; - at::Tensor qkv = at::empty(shape, at::CUDA(q.scalar_type())); - - size_t warps = q.size(0) * q.size(1); - size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; - size_t blocks = (warps + warps_per_block - 1) / warps_per_block; - dim3 grid(blocks, 3); - int threads = flash_attention::block_size; - if (q.scalar_type() == at::ScalarType::Half) { - using dtype = at::Half; - flash_attention::prepare_kernel_bwd<<>>( - q.data_ptr(), - k.data_ptr(), - v.data_ptr(), - qkv.data_ptr(), - q.size(0), - q.size(1), - q.size(2), - q.size(3)); - } else { - using dtype = at::BFloat16; - flash_attention::prepare_kernel_bwd<<>>( - q.data_ptr(), - k.data_ptr(), - v.data_ptr(), - qkv.data_ptr(), - q.size(0), - q.size(1), - q.size(2), - q.size(3)); - } + NVTE_CHECK(q.is_contiguous()); + NVTE_CHECK(k.is_contiguous()); + NVTE_CHECK(v.is_contiguous()); + NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(q.scalar_type() == at::ScalarType::Half || + q.scalar_type() == at::ScalarType::BFloat16); + NVTE_CHECK(k.scalar_type() == q.scalar_type()); + NVTE_CHECK(v.scalar_type() == q.scalar_type()); + NVTE_CHECK(q.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(q.size(3) == flash_attention::load_size); + NVTE_CHECK(k.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(k.size(3) == flash_attention::load_size); + NVTE_CHECK(v.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(v.size(3) == flash_attention::load_size); + + // 3 x [s, b, n, h] -> [b, s, n, 3 * h] + + std::vector shape = {q.size(1), q.size(0), q.size(2), 3 * q.size(3)}; + at::Tensor qkv = at::empty(shape, at::CUDA(q.scalar_type())); + + size_t warps = q.size(0) * q.size(1); + size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; + size_t blocks = (warps + warps_per_block - 1) / 
warps_per_block; + dim3 grid(blocks, 3); + int threads = flash_attention::block_size; + if (q.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + flash_attention::prepare_kernel_bwd + <<>>( + q.data_ptr(), k.data_ptr(), v.data_ptr(), qkv.data_ptr(), + q.size(0), q.size(1), q.size(2), q.size(3)); + } else { + using dtype = at::BFloat16; + flash_attention::prepare_kernel_bwd + <<>>( + q.data_ptr(), k.data_ptr(), v.data_ptr(), qkv.data_ptr(), + q.size(0), q.size(1), q.size(2), q.size(3)); + } - return qkv; + return qkv; } /*************************************************************************************************** * Support THD format for Context Parallel: Binary search **************************************************************************************************/ -__forceinline__ -__device__ int binary_search(int target, int *array, int len) { +__forceinline__ __device__ int binary_search(int target, int *array, int len) { int left = 1, right = len - 1; while (left < right) { int mid = (left + right) / 2; @@ -1697,12 +1377,8 @@ __device__ int binary_search(int target, int *array, int len) { * Support THD format for Context Parallel: Read the half of a THD tensor **************************************************************************************************/ -__global__ void thd_read_half_tensor_kernel(void *half, - void *tensor, - int *cu_seqlens, - int batch, - int hidden_size_in_bytes, - int half_idx, +__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, + int hidden_size_in_bytes, int half_idx, int dim_size_of_token) { extern __shared__ int cu_seqlens_s[]; for (int i = threadIdx.x; i <= batch; i += blockDim.x) { @@ -1717,20 +1393,20 @@ __global__ void thd_read_half_tensor_kernel(void *half, int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); size_t offset = static_cast(dim_size_of_token) * hidden_size_in_bytes; - half = reinterpret_cast(reinterpret_cast(half) + offset/2 * blockIdx.y); - tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); + half = reinterpret_cast(reinterpret_cast(half) + offset / 2 * blockIdx.y); + tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; - float4* cur_half_token = reinterpret_cast(reinterpret_cast(half) + \ - offset_in_bytes); + float4 *cur_half_token = + reinterpret_cast(reinterpret_cast(half) + offset_in_bytes); - offset_in_bytes = (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * \ - hidden_size_in_bytes; - float4* cur_token = reinterpret_cast(reinterpret_cast(tensor) + \ - offset_in_bytes); + offset_in_bytes = + (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes; + float4 *cur_token = + reinterpret_cast(reinterpret_cast(tensor) + offset_in_bytes); for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { cur_half_token[idx] = cur_token[idx]; @@ -1738,8 +1414,7 @@ __global__ void thd_read_half_tensor_kernel(void *half, } } -at::Tensor thd_read_half_tensor(const at::Tensor &tensor, - const at::Tensor &cu_seqlens, +at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens, int half_idx) { NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); @@ -1751,7 +1426,7 @@ at::Tensor 
thd_read_half_tensor(const at::Tensor &tensor, int seq_dim = tensor.dim() == 3 ? 0 : 1; int batch = cu_seqlens.size(0) - 1; - int num_heads = tensor.size(seq_dim + 1); + int num_heads = tensor.size(seq_dim + 1); int dim_per_head = tensor.size(seq_dim + 2); int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type()); @@ -1774,15 +1449,10 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, grid_y *= tensor.size(i); } dim3 grid = {grid_x, grid_y}; - thd_read_half_tensor_kernel<<>>( - half.data_ptr(), - tensor.data_ptr(), - cu_seqlens.data_ptr(), - batch, - hidden_size_in_bytes, - half_idx, - tensor.size(seq_dim)); + half.data_ptr(), tensor.data_ptr(), cu_seqlens.data_ptr(), batch, hidden_size_in_bytes, + half_idx, tensor.size(seq_dim)); return half; } @@ -1792,8 +1462,8 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, **************************************************************************************************/ template -__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, - int batch, int num_heads, int max_seqlen) { +__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, + int num_heads, int max_seqlen) { extern __shared__ int cu_seqlens_s[]; for (int i = threadIdx.x; i <= batch; i += blockDim.x) { cu_seqlens_s[i] = cu_seqlens[i] / 2; @@ -1820,8 +1490,8 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, } struct LseCorrectionFunctor { - __forceinline__ - __device__ static void run(double *lse, float *half_lse, size_t idx, size_t half_idx) { + __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx, + size_t half_idx) { double val = lse[idx]; float val_per_step = half_lse[half_idx]; double max_scale = max(val, val_per_step); @@ -1830,10 +1500,8 @@ struct LseCorrectionFunctor { } }; -void thd_second_half_lse_correction(at::Tensor lse, - const at::Tensor &lse_per_step, - const at::Tensor &cu_seqlens, - int total_tokens) { +void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens, int total_tokens) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double); NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); @@ -1842,8 +1510,8 @@ void thd_second_half_lse_correction(at::Tensor lse, NVTE_CHECK(lse_per_step.dim() == 3); NVTE_CHECK(cu_seqlens.dim() == 1); - int batch = lse.size(0); - int num_heads = lse.size(1); + int batch = lse.size(0); + int num_heads = lse.size(1); int max_seqlen = lse.size(2); NVTE_CHECK(lse_per_step.size(0) == batch); @@ -1855,33 +1523,28 @@ void thd_second_half_lse_correction(at::Tensor lse, unsigned int grid_x = (total_tokens / 2 + block - 1) / block; unsigned int grid_y = num_heads; dim3 grid = {grid_x, grid_y}; - thd_lse_kernel<<>>( - lse.data_ptr(), - lse_per_step.data_ptr(), - cu_seqlens.data_ptr(), - batch, - num_heads, - max_seqlen); + thd_lse_kernel + <<>>( + lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, + num_heads, max_seqlen); } struct ReadLseFunctor { - __forceinline__ - __device__ static void run(float *lse, float *half_lse, size_t idx, size_t half_idx) { + __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx, + size_t half_idx) { half_lse[half_idx] = lse[idx]; } }; -at::Tensor thd_read_second_half_lse(const at::Tensor &lse, - const at::Tensor &cu_seqlens, +at::Tensor thd_read_second_half_lse(const 
at::Tensor &lse, const at::Tensor &cu_seqlens, int total_tokens) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); NVTE_CHECK(lse.dim() == 3); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.dim() == 1); - int batch = lse.size(0); - int num_heads = lse.size(1); + int batch = lse.size(0); + int num_heads = lse.size(1); int max_seqlen = lse.size(2); NVTE_CHECK(cu_seqlens.size(0) == batch + 1); @@ -1893,14 +1556,10 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, unsigned int grid_x = (total_tokens / 2 + block - 1) / block; unsigned int grid_y = num_heads; dim3 grid = {grid_x, grid_y}; - thd_lse_kernel<<>>( - lse.data_ptr(), - half_lse.data_ptr(), - cu_seqlens.data_ptr(), - batch, - num_heads, - max_seqlen); + thd_lse_kernel + <<>>( + lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, + num_heads, max_seqlen); return half_lse; } @@ -1910,15 +1569,9 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, **************************************************************************************************/ template -__global__ void thd_out_correction_kernel(dtype *out, - dtype *out_per_step, - float *lse, - float *lse_per_step, - int *cu_seqlens, - int batch, - int num_heads, - int dim_per_head, - int max_seqlen) { +__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, + float *lse_per_step, int *cu_seqlens, int batch, + int num_heads, int dim_per_head, int max_seqlen) { extern __shared__ int cu_seqlens_s[]; for (int i = threadIdx.x; i <= batch; i += blockDim.x) { cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); @@ -1950,24 +1603,22 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *cur_out_per_step = out_per_step + idx_per_step; for (int j = lane_id; j < num_loops_per_head; j += tile_size) { - float4 data_per_step = reinterpret_cast(cur_out_per_step)[j]; - float4 data = reinterpret_cast(cur_out)[j]; - dtype *p_per_step = reinterpret_cast(&data_per_step); - dtype *p = reinterpret_cast(&data); + float4 data_per_step = reinterpret_cast(cur_out_per_step)[j]; + float4 data = reinterpret_cast(cur_out)[j]; + dtype *p_per_step = reinterpret_cast(&data_per_step); + dtype *p = reinterpret_cast(&data); for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { p[k] += p_per_step[k] * lse_corrected_exp; } - reinterpret_cast(cur_out)[j] = data; + reinterpret_cast(cur_out)[j] = data; } } } } -template -static void thd_out_correction_helper(at::Tensor out, - const at::Tensor &out_per_step, - const at::Tensor &lse, - const at::Tensor &lse_per_step, +template +static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step, + const at::Tensor &lse, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens) { NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type()); NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); @@ -1975,9 +1626,9 @@ static void thd_out_correction_helper(at::Tensor out, NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); int total_tokens = out.size(0); - int num_heads = out.size(1); + int num_heads = out.size(1); int dim_per_head = out.size(2); - int batch = lse.size(0); + int batch = lse.size(0); int max_seqlen = lse.size(2); NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1)); @@ -1991,28 +1642,19 @@ static void thd_out_correction_helper(at::Tensor out, constexpr int tile = 16; constexpr int block = 512; - unsigned int grid_x = (static_cast(total_tokens) / (only_second_half + 1) * \ - tile + block - 
1) / block; + unsigned int grid_x = + (static_cast(total_tokens) / (only_second_half + 1) * tile + block - 1) / block; dim3 grid = {grid_x, (unsigned int)num_heads}; - thd_out_correction_kernel<<>>( - out.data_ptr(), - out_per_step.data_ptr(), - lse.data_ptr(), - lse_per_step.data_ptr(), - cu_seqlens.data_ptr(), - batch, - num_heads, - dim_per_head, - max_seqlen); + thd_out_correction_kernel + <<>>( + out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), + lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, + dim_per_head, max_seqlen); } -void thd_out_correction(at::Tensor out, - const at::Tensor &out_per_step, - const at::Tensor &lse, - const at::Tensor &lse_per_step, - const at::Tensor &cu_seqlens, +void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, + const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, bool only_second_half) { if (only_second_half) { if (out.scalar_type() == at::ScalarType::Half) { @@ -2048,12 +1690,8 @@ void thd_out_correction(at::Tensor out, **************************************************************************************************/ template -__global__ void thd_grad_correction_kernel(dtype *grad, - dtype *grad_per_step, - int *cu_seqlens, - int batch, - int hidden_size, - int dim_size_of_token) { +__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens, + int batch, int hidden_size, int dim_size_of_token) { extern __shared__ int cu_seqlens_s[]; for (int i = threadIdx.x; i <= batch; i += blockDim.x) { if constexpr (functor_idx < 2) { @@ -2105,39 +1743,35 @@ __global__ void thd_grad_correction_kernel(dtype *grad, } struct EmptyFunctor { - __forceinline__ - __device__ static void run(void *token, void *token_per_step, int idx) {} + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {} }; struct CopyFunctor { - __forceinline__ - __device__ static void run(void *token, void *token_per_step, int idx) { - reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) { + reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; } }; template struct AddFunctor { - __forceinline__ - __device__ static void run(dtype *token, dtype *token_per_step, int idx) { - float4 d_ = reinterpret_cast(token)[idx]; - dtype *p_ = reinterpret_cast(&d_); + __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) { + float4 d_ = reinterpret_cast(token)[idx]; + dtype *p_ = reinterpret_cast(&d_); - float4 d = reinterpret_cast(token_per_step)[idx]; - dtype *p = reinterpret_cast(&d); + float4 d = reinterpret_cast(token_per_step)[idx]; + dtype *p = reinterpret_cast(&d); - #pragma unroll +#pragma unroll for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { p_[i] += p[i]; } - reinterpret_cast(token)[idx] = d_; + reinterpret_cast(token)[idx] = d_; } }; template -static void thd_grad_correction_helper(at::Tensor grad, - const at::Tensor &grad_per_step, +static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_per_step, const at::Tensor &cu_seqlens) { NVTE_CHECK(grad.dim() == 3 || grad.dim() == 4); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); @@ -2148,7 +1782,7 @@ static void thd_grad_correction_helper(at::Tensor grad, int seq_dim = grad.dim() == 3 ? 
0 : 1; int total_tokens = grad.size(seq_dim); - int num_heads = grad.size(seq_dim + 1); + int num_heads = grad.size(seq_dim + 1); int dim_per_head = grad.size(seq_dim + 2); int batch = cu_seqlens.size(0) - 1; @@ -2177,48 +1811,40 @@ static void thd_grad_correction_helper(at::Tensor grad, dim3 grid = {grid_x, grid_y}; thd_grad_correction_kernel - <<>>( - grad.data_ptr(), - grad_per_step.data_ptr(), - cu_seqlens.data_ptr(), - batch, - hidden_size, - total_tokens); + <<>>( + grad.data_ptr(), grad_per_step.data_ptr(), cu_seqlens.data_ptr(), + batch, hidden_size, total_tokens); } template -static void thd_grad_dispatcher(at::Tensor grad, - const at::Tensor &grad_per_step, - const at::Tensor &cu_seqlens, - const std::string &first_half, +static void thd_grad_dispatcher(at::Tensor grad, const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens, const std::string &first_half, const std::string &second_half) { if (first_half == "add" && second_half == "none") { - thd_grad_correction_helper, EmptyFunctor, 0>( - grad, grad_per_step, cu_seqlens); + thd_grad_correction_helper, EmptyFunctor, 0>(grad, grad_per_step, + cu_seqlens); } else if (first_half == "copy" && second_half == "none") { - thd_grad_correction_helper( - grad, grad_per_step, cu_seqlens); + thd_grad_correction_helper(grad, grad_per_step, + cu_seqlens); } else if (first_half == "none" && second_half == "add") { - thd_grad_correction_helper, 1>( - grad, grad_per_step, cu_seqlens); + thd_grad_correction_helper, 1>(grad, grad_per_step, + cu_seqlens); } else if (first_half == "none" && second_half == "copy") { - thd_grad_correction_helper( - grad, grad_per_step, cu_seqlens); + thd_grad_correction_helper(grad, grad_per_step, + cu_seqlens); } else if (first_half == "add" && second_half == "copy") { - thd_grad_correction_helper, CopyFunctor, 2>( - grad, grad_per_step, cu_seqlens); + thd_grad_correction_helper, CopyFunctor, 2>(grad, grad_per_step, + cu_seqlens); } else if (first_half == "copy" && second_half == "add") { - thd_grad_correction_helper, 2>( - grad, grad_per_step, cu_seqlens); + thd_grad_correction_helper, 2>(grad, grad_per_step, + cu_seqlens); } else { NVTE_ERROR("Unsupported Functor of first half and second_half\n"); } } -void thd_grad_correction(at::Tensor grad, - const at::Tensor &grad_per_step, - const at::Tensor &cu_seqlens, - const std::string &first_half, +void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens, const std::string &first_half, const std::string &second_half) { if (grad.scalar_type() == at::ScalarType::Half) { thd_grad_dispatcher(grad, grad_per_step, cu_seqlens, first_half, second_half); @@ -2235,18 +1861,14 @@ void thd_grad_correction(at::Tensor grad, * Support THD format for Context Parallel: Generate partitioned indices for input tokens **************************************************************************************************/ -__global__ void thd_partition_indices_kernel(int *output, - int *cu_seqlens, - int batch, - int total_tokens, - int world_size, - int rank) { +__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, + int total_tokens, int world_size, int rank) { extern __shared__ int cu_seqlens_s[]; for (int i = threadIdx.x; i <= batch; i += blockDim.x) { int seqlen = cu_seqlens[i]; // Currently we assume that each sequence length is divisible by (world_size*2) since we have // to distribute each sequence evenly to different GPUs. 
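    // Illustrative note (inferred from the index arithmetic below, stated here as
    // an assumption): with this divisibility requirement each full sequence can be
    // viewed as 2 * world_size equal chunks, and rank r keeps chunk r together with
    // chunk (2 * world_size - 1 - r), which balances causal-attention work across
    // ranks.  For example, with world_size = 4 the 8 chunks of a sequence are
    // assigned as
    //
    //     rank 0 -> {0, 7},   rank 1 -> {1, 6},   rank 2 -> {2, 5},   rank 3 -> {3, 4},
    //
    // which is what "offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank"
    // below encodes.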
- assert(seqlen % (world_size*2) == 0); + assert(seqlen % (world_size * 2) == 0); cu_seqlens_s[i] = seqlen / world_size; } __syncthreads(); @@ -2258,16 +1880,14 @@ __global__ void thd_partition_indices_kernel(int *output, int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; int index = token_id - cu_seqlens_s[seq_id]; - int offset = index < seq_len/2 ? rank : (world_size-1) * 2 - rank; + int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank; index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; output[token_id] = index; } } -at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, - int total_tokens, - int world_size, - int rank) { +at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, + int world_size, int rank) { NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.dim() == 1); NVTE_CHECK(cu_seqlens.size(0) >= 2); @@ -2282,14 +1902,9 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, constexpr unsigned int block = 256; unsigned int grid = (output.size(0) + block - 1) / block; - thd_partition_indices_kernel<<>>( - output.data_ptr(), - cu_seqlens.data_ptr(), - batch, - total_tokens, - world_size, - rank); + output.data_ptr(), cu_seqlens.data_ptr(), batch, total_tokens, world_size, rank); return output; } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cu index c798a39df53066de49c88967ceb1052bfd486d37..c783c9d9882270429bf7174a050d043555ff5fe7 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cu +++ b/transformer_engine/pytorch/csrc/extensions/cast.cu @@ -6,74 +6,53 @@ #include "extensions.h" +at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, + at::Tensor scale_inv, transformer_engine::DType otype) { + using namespace transformer_engine; + auto input_shape = input.sizes().vec(); + std::vector shape{input_shape.begin(), input_shape.end()}; -at::Tensor cast_to_fp8(const at::Tensor &input, - const at::Tensor &scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { - using namespace transformer_engine; - auto input_shape = input.sizes().vec(); - std::vector shape{input_shape.begin(), input_shape.end()}; + auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); - auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); + if (input.numel() == 0) return output; - if (input.numel() == 0) - return output; + auto input_cu = makeTransformerEngineTensor(input); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); - auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); + nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - nvte_fp8_quantize(input_cu.data(), output_cu.data(), - at::cuda::getCurrentCUDAStream()); - - return output; + return output; } +void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype) { + using namespace transformer_engine; + size_t N = static_cast(input.size(0)); + size_t H = static_cast(input.size(1)); -void cast_to_fp8_noalloc(const at::Tensor 
&input, - const at::Tensor &scale, - at::Tensor output, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { - using namespace transformer_engine; - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); + auto input_cu = makeTransformerEngineTensor(input); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); - nvte_fp8_quantize(input_cu.data(), output_cu.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - return; + return; } +at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, + transformer_engine::DType itype, transformer_engine::DType otype) { + using namespace transformer_engine; + auto input_shape = input.sizes().vec(); + std::vector shape{input_shape.begin(), input_shape.end()}; -at::Tensor cast_from_fp8(const at::Tensor &input, - const at::Tensor &scale_inv, - transformer_engine::DType itype, - transformer_engine::DType otype -) { - using namespace transformer_engine; - auto input_shape = input.sizes().vec(); - std::vector shape{input_shape.begin(), input_shape.end()}; - - auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); + auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, - nullptr, nullptr, scale_inv.data_ptr()); - auto output_cu = makeTransformerEngineTensor(output); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, nullptr, nullptr, + scale_inv.data_ptr()); + auto output_cu = makeTransformerEngineTensor(output); - nvte_fp8_dequantize(input_cu.data(), output_cu.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fp8_dequantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - return output; + return output; } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index 5f2757751e2df2dd5cb7e5559836ecb91e8535ff..19cd4675bc739b95437c8b615a8f5d1fbd7e98cd 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -6,150 +6,74 @@ #include "extensions.h" - -void te_gemm(at::Tensor A, - at::Tensor A_scale_inverse, - transformer_engine::DType A_type, - bool transa, - at::Tensor B, - at::Tensor B_scale_inverse, - transformer_engine::DType B_type, - bool transb, - at::Tensor D, - at::Tensor D_scale, - transformer_engine::DType D_type, - at::Tensor D_amax, - at::Tensor bias, - transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, - bool grad, - at::Tensor workspace, - size_t workspaceSize, - bool accumulate, - bool use_split_accumulator, - int math_sm_count -) { +void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, + bool transa, at::Tensor B, at::Tensor B_scale_inverse, + transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count) { using 
namespace transformer_engine; - auto te_A = makeTransformerEngineTensor(A.data_ptr(), - {static_cast(A.size(0)), - static_cast(A.size(1))}, - A_type, nullptr, nullptr, - A_scale_inverse.data_ptr()); - auto te_B = makeTransformerEngineTensor(B.data_ptr(), - {static_cast(B.size(0)), - static_cast(B.size(1))}, - B_type, nullptr, nullptr, - B_scale_inverse.data_ptr()); - auto te_D = makeTransformerEngineTensor(D.data_ptr(), - {static_cast(D.size(0)), - static_cast(D.size(1))}, - D_type, D_amax.data_ptr(), - D_scale.data_ptr(), nullptr); - auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, - bias_type); + auto te_A = makeTransformerEngineTensor( + A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, + nullptr, nullptr, A_scale_inverse.data_ptr()); + auto te_B = makeTransformerEngineTensor( + B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, + nullptr, nullptr, B_scale_inverse.data_ptr()); + auto te_D = makeTransformerEngineTensor( + D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, + D_amax.data_ptr(), D_scale.data_ptr(), nullptr); + auto te_bias = + makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, bias_type); const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; - auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out.data_ptr(), - gelu_shape, - GetTransformerEngineDType( - pre_gelu_out.scalar_type())); - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - {workspaceSize}, - DType::kByte); + ? std::vector{static_cast(pre_gelu_out.size(0))} + : std::vector{static_cast(pre_gelu_out.size(0)), + static_cast(pre_gelu_out.size(1))}; + auto te_pre_gelu_out = makeTransformerEngineTensor( + pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); + auto te_workspace = + makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte); - nvte_cublas_gemm(te_A.data(), - te_B.data(), - te_D.data(), - te_bias.data(), - te_pre_gelu_out.data(), - transa, - transb, - grad, - te_workspace.data(), - accumulate, - use_split_accumulator, - math_sm_count, - at::cuda::getCurrentCUDAStream()); + nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(), + transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator, + math_sm_count, at::cuda::getCurrentCUDAStream()); } -void te_atomic_gemm(at::Tensor A, - at::Tensor A_scale_inverse, - transformer_engine::DType A_type, - bool transa, - at::Tensor B, - at::Tensor B_scale_inverse, - transformer_engine::DType B_type, - bool transb, - at::Tensor D, - at::Tensor D_scale, - transformer_engine::DType D_type, - at::Tensor D_amax, - at::Tensor bias, - transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, - bool grad, - at::Tensor workspace, - size_t workspaceSize, - bool accumulate, - bool use_split_accumulator, - int math_sm_count, - int m_split, - int n_split, - bool gemm_producer, - at::Tensor counter -) { +void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, + bool transa, at::Tensor B, at::Tensor B_scale_inverse, + transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor 
pre_gelu_out, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count, int m_split, int n_split, + bool gemm_producer, at::Tensor counter) { using namespace transformer_engine; - auto te_A = makeTransformerEngineTensor(A.data_ptr(), - {static_cast(A.size(0)), - static_cast(A.size(1))}, - A_type, nullptr, nullptr, - A_scale_inverse.data_ptr()); - auto te_B = makeTransformerEngineTensor(B.data_ptr(), - {static_cast(B.size(0)), - static_cast(B.size(1))}, - B_type, nullptr, nullptr, - B_scale_inverse.data_ptr()); - auto te_D = makeTransformerEngineTensor(D.data_ptr(), - {static_cast(D.size(0)), - static_cast(D.size(1))}, - D_type, D_amax.data_ptr(), - D_scale.data_ptr(), nullptr); - auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, - bias_type); - auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), - {static_cast(counter.size(0))}, - DType::kInt32); + auto te_A = makeTransformerEngineTensor( + A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, + nullptr, nullptr, A_scale_inverse.data_ptr()); + auto te_B = makeTransformerEngineTensor( + B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, + nullptr, nullptr, B_scale_inverse.data_ptr()); + auto te_D = makeTransformerEngineTensor( + D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, + D_amax.data_ptr(), D_scale.data_ptr(), nullptr); + auto te_bias = + makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, bias_type); + auto te_counter = makeTransformerEngineTensor( + counter.data_ptr(), {static_cast(counter.size(0))}, DType::kInt32); const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; - auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out.data_ptr(), - gelu_shape, - GetTransformerEngineDType( - pre_gelu_out.scalar_type())); - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - {workspaceSize}, - DType::kByte); + ? 
std::vector{static_cast(pre_gelu_out.size(0))} + : std::vector{static_cast(pre_gelu_out.size(0)), + static_cast(pre_gelu_out.size(1))}; + auto te_pre_gelu_out = makeTransformerEngineTensor( + pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); + auto te_workspace = + makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte); - nvte_cublas_atomic_gemm(te_A.data(), - te_B.data(), - te_D.data(), - te_bias.data(), - te_pre_gelu_out.data(), - transa, - transb, - grad, - te_workspace.data(), - accumulate, - use_split_accumulator, - math_sm_count, - m_split, - n_split, - gemm_producer, - te_counter.data(), - at::cuda::getCurrentCUDAStream()); + nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), + te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), + accumulate, use_split_accumulator, math_sm_count, m_split, n_split, + gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cu b/transformer_engine/pytorch/csrc/extensions/misc.cu index dd15369bf5a352a8cf55246bfa79d96590c4a06c..30f0aa9533b8ebc06d59eb176cfdd7ca971598bb 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cu +++ b/transformer_engine/pytorch/csrc/extensions/misc.cu @@ -9,12 +9,8 @@ #include "comm_gemm_overlap.h" #endif // NVTE_WITH_USERBUFFERS -size_t get_cublasLt_version() { - return cublasLtGetVersion(); -} +size_t get_cublasLt_version() { return cublasLtGetVersion(); } -size_t get_cudnn_version() { - return cudnnGetVersion(); -} +size_t get_cudnn_version() { return cudnnGetVersion(); } void placeholder() {} // TODO(ksivamani) clean this up diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu index 7c9b91dcabd2e541aa4c28e05e4c2298fc7665eb..3009a827684721b2b0dc814723f889d34abfcb25 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu @@ -35,9 +35,9 @@ template struct SGDFunctor { __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata& tl, // NOLINT(*) - float wd, float momentum, - float dampening, float lr, bool nesterov, - bool first_run, bool wd_after_momentum, float scale) { + float wd, float momentum, float dampening, float lr, + bool nesterov, bool first_run, bool wd_after_momentum, + float scale) { // Early exit if we don't need to do anything if (*noop_gmem) return; @@ -176,24 +176,21 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, // } // Case 2. fp32, fp32, fp32, No else if (grad_type == at::ScalarType::Float && // NOLINT(*) - weight_type == at::ScalarType::Float && - num_tensors == 3) { + weight_type == at::ScalarType::Float && num_tensors == 3) { multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, SGDFunctor<3, float, float>(), wd, momentum, dampening, lr, nesterov, first_run, wd_after_momentum, scale); } // Case 3. 
fp16, fp32, fp32, Yes else if (grad_type == at::ScalarType::Half && // NOLINT(*) - weight_type == at::ScalarType::Float && - num_tensors == 4) { + weight_type == at::ScalarType::Float && num_tensors == 4) { multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, SGDFunctor<4, at::Half, float>(), wd, momentum, dampening, lr, nesterov, first_run, wd_after_momentum, scale); } // Case 4. fp32, fp32, fp32, Yes else if (grad_type == at::ScalarType::Float && // NOLINT(*) - weight_type == at::ScalarType::Float && - num_tensors == 4) { + weight_type == at::ScalarType::Float && num_tensors == 4) { multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, SGDFunctor<4, float, float>(), wd, momentum, dampening, lr, nesterov, first_run, wd_after_momentum, scale); diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cu b/transformer_engine/pytorch/csrc/extensions/normalization.cu index d51635bb6f203255178c57eb984f98b06abf74c3..77bbcbc9d66bede018c745eea0fb6c034b4ddc7d 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cu +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cu @@ -6,449 +6,316 @@ #include "extensions.h" -std::vector layernorm_bwd(const at::Tensor &dz, - const at::Tensor &x, - const at::Tensor &mu, - const at::Tensor &rsigma, - const at::Tensor &gamma, - const int sm_margin, - const bool zero_centered_gamma -) { - auto dx = at::empty_like(x); - auto dgamma = at::empty_like(gamma); - auto dbeta = at::empty_like(gamma); - transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part; - - auto dz_cu = makeTransformerEngineTensor(dz); - auto x_cu = makeTransformerEngineTensor(x); - auto mu_cu = makeTransformerEngineTensor(mu); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); - auto gamma_cu = makeTransformerEngineTensor(gamma); - auto dx_cu = makeTransformerEngineTensor(dx); - auto dgamma_cu = makeTransformerEngineTensor(dgamma); - auto dbeta_cu = makeTransformerEngineTensor(dbeta); - - // This call populates tensors with the required config. - const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), - dbeta_part.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - // Alloc space for Tensors. - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); - auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); - auto dbeta_part_data = allocateSpace(dbeta_part.shape(), dbeta_part.dtype()); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), - barrier.shape(), - barrier.dtype()); - dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), - dgamma_part.shape(), - dgamma_part.dtype()); - dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(), - dbeta_part.shape(), - dbeta_part.dtype()); - - // Actual call to bwd kernel. 
- bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), - dbeta_part.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - return { dx, dgamma, dbeta }; +std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, + const at::Tensor &mu, const at::Tensor &rsigma, + const at::Tensor &gamma, const int sm_margin, + const bool zero_centered_gamma) { + auto dx = at::empty_like(x); + auto dgamma = at::empty_like(gamma); + auto dbeta = at::empty_like(gamma); + transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part; + + auto dz_cu = makeTransformerEngineTensor(dz); + auto x_cu = makeTransformerEngineTensor(x); + auto mu_cu = makeTransformerEngineTensor(mu); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + auto gamma_cu = makeTransformerEngineTensor(gamma); + auto dx_cu = makeTransformerEngineTensor(dx); + auto dgamma_cu = makeTransformerEngineTensor(dgamma); + auto dbeta_cu = makeTransformerEngineTensor(dbeta); + + // This call populates tensors with the required config. + const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; + bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), + at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), + barrier.data()); + + // Alloc space for Tensors. + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); + auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); + auto dbeta_part_data = allocateSpace(dbeta_part.shape(), dbeta_part.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); + dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), dgamma_part.shape(), + dgamma_part.dtype()); + dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(), dbeta_part.shape(), + dbeta_part.dtype()); + + // Actual call to bwd kernel. 
+ bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), + at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), + barrier.data()); + + return {dx, dgamma, dbeta}; } - -std::vector layernorm_fwd_fp8(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset, - const int amax_offset, - const int scale_inv_offset -) { - using namespace transformer_engine; - - auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); - return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, - scale, ln_out, amax, scale_inv, - otype, sm_margin, zero_centered_gamma, - scale_offset, amax_offset, scale_inv_offset); +std::vector layernorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, float eps, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, const int sm_margin, + const bool zero_centered_gamma, const int scale_offset, + const int amax_offset, const int scale_inv_offset) { + using namespace transformer_engine; + + auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); + return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, scale, ln_out, amax, scale_inv, otype, + sm_margin, zero_centered_gamma, scale_offset, amax_offset, + scale_inv_offset); } - -std::vector layernorm_fwd_fp8_noalloc(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - at::Tensor scale, - at::Tensor ln_out, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset, - const int amax_offset, - const int scale_inv_offset -) { - using namespace transformer_engine; - - // Choose kernel implementation - const auto func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; - - // Tensor dimensions - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - DType itype = GetTransformerEngineDType(input.scalar_type()); - auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input); - auto gamma_cu = makeTransformerEngineTensor(weight); - auto beta_cu = makeTransformerEngineTensor(bias); - auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), - {N, H}, - otype, - amax_dptr, - scale_dptr, - scale_inv_dptr); - auto mu_cu = makeTransformerEngineTensor(mu); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); - - // Query workspace sizes - transformer_engine::TensorWrapper workspace, barrier; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - // Allocate workspaces - auto workspace_data = allocateSpace(workspace.shape(), - workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), - barrier.dtype(), - true); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), - barrier.shape(), - barrier.dtype()); - - // Launch kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - return {ln_out, mu, rsigma}; +std::vector layernorm_fwd_fp8_noalloc( + const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps, + at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma, + const int scale_offset, const int amax_offset, const int scale_inv_offset) { + using namespace transformer_engine; + + // Choose kernel implementation + const auto func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; + + // Tensor dimensions + size_t N = static_cast(input.size(0)); + size_t H = static_cast(input.size(1)); + + // Get pointers for FP8 scale, amax, scale-inverse + void *scale_dptr = getDataPtr(scale, scale_offset); + void *amax_dptr = getDataPtr(amax, amax_offset); + void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + + // Construct Transformer Engine tensors + DType itype = GetTransformerEngineDType(input.scalar_type()); + auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + auto input_cu = makeTransformerEngineTensor(input); + auto gamma_cu = makeTransformerEngineTensor(weight); + auto beta_cu = makeTransformerEngineTensor(bias); + auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr, + scale_inv_dptr); + auto mu_cu = makeTransformerEngineTensor(mu); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + + // Query workspace sizes + transformer_engine::TensorWrapper workspace, barrier; + func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), + rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), + barrier.data()); + + // Allocate workspaces + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); + + // Launch kernel + func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), + rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), + barrier.data()); + + return {ln_out, mu, rsigma}; } - -at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset, - const int amax_offset, - const int scale_inv_offset +at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, float eps, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, const int sm_margin, + const bool zero_centered_gamma, const int scale_offset, + const int amax_offset, const int scale_inv_offset ) { - // This is a specialized version of layernorm_fwd_fp8, optimized for inference, - // which only returns the normalized output. - std::vector out = layernorm_fwd_fp8( - input, weight, bias, eps, - scale, amax, scale_inv, - otype, sm_margin, zero_centered_gamma, - scale_offset, amax_offset, scale_inv_offset); - return out[0]; + // This is a specialized version of layernorm_fwd_fp8, optimized for inference, + // which only returns the normalized output. 
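  // (The mu and rsigma statistics that layernorm_fwd_fp8 also returns are only
  // consumed by the backward pass, so the inference path can safely drop them.)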
+ std::vector out = + layernorm_fwd_fp8(input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin, + zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset); + return out[0]; } +std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, float eps, const int sm_margin, + const bool zero_centered_gamma) { + using namespace transformer_engine; -std::vector layernorm_fwd(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - const int sm_margin, - const bool zero_centered_gamma -) { - using namespace transformer_engine; - - DType itype = GetTransformerEngineDType(input.scalar_type()); - auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); + DType itype = GetTransformerEngineDType(input.scalar_type()); + auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); - return layernorm_fwd_noalloc(input, weight, bias, ln_out, eps, - sm_margin, zero_centered_gamma); + return layernorm_fwd_noalloc(input, weight, bias, ln_out, eps, sm_margin, zero_centered_gamma); } +std::vector layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, at::Tensor ln_out, float eps, + const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine; -std::vector layernorm_fwd_noalloc(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - at::Tensor ln_out, - float eps, - const int sm_margin, - const bool zero_centered_gamma -) { - using namespace transformer_engine; - - DType itype = GetTransformerEngineDType(input.scalar_type()); + DType itype = GetTransformerEngineDType(input.scalar_type()); - return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, at::Tensor(), - ln_out, at::Tensor(), at::Tensor(), - itype, sm_margin, zero_centered_gamma); + return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, at::Tensor(), ln_out, at::Tensor(), + at::Tensor(), itype, sm_margin, zero_centered_gamma); } - -at::Tensor layernorm_fwd_inf(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - const int sm_margin, - const bool zero_centered_gamma -) { - // This is a specialized version of layernorm_fwd, optimized for inference, - // which only returns the normalized output. - std::vector out = layernorm_fwd(input, weight, bias, eps, sm_margin, - zero_centered_gamma); - return out[0]; +at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, float eps, const int sm_margin, + const bool zero_centered_gamma) { + // This is a specialized version of layernorm_fwd, optimized for inference, + // which only returns the normalized output. + std::vector out = + layernorm_fwd(input, weight, bias, eps, sm_margin, zero_centered_gamma); + return out[0]; } -std::vector rmsnorm_bwd(const at::Tensor &dz, - const at::Tensor &x, - const at::Tensor &rsigma, - const at::Tensor &gamma, - const int sm_margin, - const bool zero_centered_gamma -) { - auto dx = at::empty_like(x); - auto dgamma = at::empty_like(gamma); - transformer_engine::TensorWrapper workspace, barrier, dgamma_part; - - auto dz_cu = makeTransformerEngineTensor(dz); - auto x_cu = makeTransformerEngineTensor(x); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); - auto gamma_cu = makeTransformerEngineTensor(gamma); - auto dx_cu = makeTransformerEngineTensor(dx); - auto dgamma_cu = makeTransformerEngineTensor(dgamma); - - // This call populates tensors with the required config. 
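  // Illustrative summary of the two-pass convention used throughout this file
  // (a sketch mirroring the calls above and below, not new behaviour): each
  // nvte_* entry point is invoked twice.  The first call receives empty
  // TensorWrapper objects and only fills in the shapes/dtypes the kernel needs;
  // those buffers are then allocated with allocateSpace(), re-wrapped, and the
  // same function is called a second time to do the actual work.  Roughly:
  //
  //     TensorWrapper workspace;                          // empty: sizes unknown
  //     bwd_fun(..., workspace.data(), barrier.data());   // pass 1: query sizes
  //     auto buf = allocateSpace(workspace.shape(), workspace.dtype());
  //     workspace = makeTransformerEngineTensor(buf.data_ptr(), workspace.shape(),
  //                                             workspace.dtype());
  //     bwd_fun(..., workspace.data(), barrier.data());   // pass 2: run the kernel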
- const auto bwd_fun = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dgamma_part.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - // Alloc space for Tensors. - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); - auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), - barrier.shape(), - barrier.dtype()); - dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), - dgamma_part.shape(), - dgamma_part.dtype()); - - // Actual call to bwd kernel. - bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dgamma_part.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - return { dx, dgamma }; +std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, + const at::Tensor &rsigma, const at::Tensor &gamma, + const int sm_margin, const bool zero_centered_gamma) { + auto dx = at::empty_like(x); + auto dgamma = at::empty_like(gamma); + transformer_engine::TensorWrapper workspace, barrier, dgamma_part; + + auto dz_cu = makeTransformerEngineTensor(dz); + auto x_cu = makeTransformerEngineTensor(x); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + auto gamma_cu = makeTransformerEngineTensor(gamma); + auto dx_cu = makeTransformerEngineTensor(dx); + auto dgamma_cu = makeTransformerEngineTensor(dgamma); + + // This call populates tensors with the required config. + const auto bwd_fun = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd; + bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), dgamma_part.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), + barrier.data()); + + // Alloc space for Tensors. + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); + auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); + dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), dgamma_part.shape(), + dgamma_part.dtype()); + + // Actual call to bwd kernel. 
+ bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), dgamma_part.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), + barrier.data()); + + return {dx, dgamma}; } - -std::vector rmsnorm_fwd_fp8(const at::Tensor &input, - const at::Tensor &weight, - float eps, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset, - const int amax_offset, - const int scale_inv_offset -) { - using namespace transformer_engine; - - auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); - return rmsnorm_fwd_fp8_noalloc(input, weight, eps, - scale, ln_out, amax, scale_inv, - otype, sm_margin, zero_centered_gamma, - scale_offset, amax_offset, scale_inv_offset); +std::vector rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, + float eps, at::Tensor scale, at::Tensor amax, + at::Tensor scale_inv, transformer_engine::DType otype, + const int sm_margin, const bool zero_centered_gamma, + const int scale_offset, const int amax_offset, + const int scale_inv_offset) { + using namespace transformer_engine; + + auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); + return rmsnorm_fwd_fp8_noalloc(input, weight, eps, scale, ln_out, amax, scale_inv, otype, + sm_margin, zero_centered_gamma, scale_offset, amax_offset, + scale_inv_offset); } - -std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, - const at::Tensor &weight, - float eps, - at::Tensor scale, - at::Tensor ln_out, - at::Tensor amax, - at::Tensor scale_inv, +std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const at::Tensor &weight, + float eps, at::Tensor scale, at::Tensor ln_out, + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset, - const int amax_offset, - const int scale_inv_offset -) { - using namespace transformer_engine; - - // Choose kernel implementation - const auto func = zero_centered_gamma ? 
nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd; - - // Tensor dimensions - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - DType itype = GetTransformerEngineDType(input.scalar_type()); - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input); - auto gamma_cu = makeTransformerEngineTensor(weight); - auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), - {N, H}, - otype, - amax_dptr, - scale_dptr, - scale_inv_dptr); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); - - // Query workspace sizes - transformer_engine::TensorWrapper workspace, barrier; - func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), - rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - // Allocate workspaces - auto workspace_data = allocateSpace(workspace.shape(), - workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), - barrier.dtype(), - true); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), - barrier.shape(), - barrier.dtype()); - - // Launch kernel - func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), - rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - return {ln_out, rsigma}; + const int sm_margin, const bool zero_centered_gamma, + const int scale_offset, const int amax_offset, + const int scale_inv_offset) { + using namespace transformer_engine; + + // Choose kernel implementation + const auto func = zero_centered_gamma ? 
nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd; + + // Tensor dimensions + size_t N = static_cast(input.size(0)); + size_t H = static_cast(input.size(1)); + + // Get pointers for FP8 scale, amax, scale-inverse + void *scale_dptr = getDataPtr(scale, scale_offset); + void *amax_dptr = getDataPtr(amax, amax_offset); + void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + + // Construct Transformer Engine tensors + DType itype = GetTransformerEngineDType(input.scalar_type()); + auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + auto input_cu = makeTransformerEngineTensor(input); + auto gamma_cu = makeTransformerEngineTensor(weight); + auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr, + scale_inv_dptr); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + + // Query workspace sizes + transformer_engine::TensorWrapper workspace, barrier; + func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), + at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), + barrier.data()); + + // Allocate workspaces + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); + + // Launch kernel + func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), + at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), + barrier.data()); + + return {ln_out, rsigma}; } - -at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, - const at::Tensor &weight, - float eps, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset, - const int amax_offset, - const int scale_inv_offset -) { - // This is a specialized version of rmsnorm_fwd_fp8, optimized for inference, - // which only returns the normalized output. - std::vector out = rmsnorm_fwd_fp8( - input, weight, eps, - scale, amax, scale_inv, - otype, sm_margin, zero_centered_gamma, - scale_offset, amax_offset, scale_inv_offset); - return out[0]; +at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, float eps, + at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, const int sm_margin, + const bool zero_centered_gamma, const int scale_offset, + const int amax_offset, const int scale_inv_offset) { + // This is a specialized version of rmsnorm_fwd_fp8, optimized for inference, + // which only returns the normalized output. 
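+  // Note: rmsnorm_fwd_fp8 still computes and allocates the rsigma statistics;
+  // this wrapper only drops them from the returned vector, so the saving is in
+  // Python-side bookkeeping rather than in kernel work.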
+ std::vector out = + rmsnorm_fwd_fp8(input, weight, eps, scale, amax, scale_inv, otype, sm_margin, + zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset); + return out[0]; } +std::vector rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps, + const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine; -std::vector rmsnorm_fwd(const at::Tensor &input, - const at::Tensor &weight, - float eps, - const int sm_margin, - const bool zero_centered_gamma -) { - using namespace transformer_engine; - - DType itype = GetTransformerEngineDType(input.scalar_type()); - auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); + DType itype = GetTransformerEngineDType(input.scalar_type()); + auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); - return rmsnorm_fwd_noalloc(input, weight, ln_out, eps, - sm_margin, zero_centered_gamma); + return rmsnorm_fwd_noalloc(input, weight, ln_out, eps, sm_margin, zero_centered_gamma); } +std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, + at::Tensor ln_out, float eps, const int sm_margin, + const bool zero_centered_gamma) { + using namespace transformer_engine; -std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, - const at::Tensor &weight, - at::Tensor ln_out, - float eps, - const int sm_margin, - const bool zero_centered_gamma -) { - using namespace transformer_engine; - - DType itype = GetTransformerEngineDType(input.scalar_type()); + DType itype = GetTransformerEngineDType(input.scalar_type()); - return rmsnorm_fwd_fp8_noalloc(input, weight, eps, at::Tensor(), - ln_out, at::Tensor(), at::Tensor(), - itype, sm_margin, zero_centered_gamma); + return rmsnorm_fwd_fp8_noalloc(input, weight, eps, at::Tensor(), ln_out, at::Tensor(), + at::Tensor(), itype, sm_margin, zero_centered_gamma); } - -at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, - const at::Tensor &weight, - float eps, - const int sm_margin, - const bool zero_centered_gamma -) { - // This is a specialized version of rmsnorm_fwd, optimized for inference, - // which only returns the normalized output. - std::vector out = rmsnorm_fwd(input, weight, eps, sm_margin, zero_centered_gamma); - return out[0]; +at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps, + const int sm_margin, const bool zero_centered_gamma) { + // This is a specialized version of rmsnorm_fwd, optimized for inference, + // which only returns the normalized output. 
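+  // As with rmsnorm_fwd_fp8_inf above, the full forward pass (including rsigma)
+  // is still executed; only the returned tensor list is trimmed to the output.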
+ std::vector out = rmsnorm_fwd(input, weight, eps, sm_margin, zero_centered_gamma); + return out[0]; } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c2afa06d1403dface56bb7362574179acddf8299..609341b519abdf587fa74834ab5221484451545c 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -6,173 +6,93 @@ #include -#include "../extensions.h" #include "../comm_gemm_overlap.h" +#include "../extensions.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Softmax functions m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD"); m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD"); m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward, - "Scaled Masked Softmax FWD"); + "Scaled Masked Softmax FWD"); m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward, - "Scaled Masked Softmax BWD"); - m.def("scaled_upper_triang_masked_softmax_forward", - &scaled_upper_triang_masked_softmax_forward, - "Scaled Upper-Triangular Masked Softmax FWD"); - m.def("scaled_upper_triang_masked_softmax_backward", - &scaled_upper_triang_masked_softmax_backward, - "Scaled Upper-Triangular Masked Softmax BWD"); + "Scaled Masked Softmax BWD"); + m.def("scaled_upper_triang_masked_softmax_forward", &scaled_upper_triang_masked_softmax_forward, + "Scaled Upper-Triangular Masked Softmax FWD"); + m.def("scaled_upper_triang_masked_softmax_backward", &scaled_upper_triang_masked_softmax_backward, + "Scaled Upper-Triangular Masked Softmax BWD"); m.def("scaled_aligned_causal_masked_softmax_forward", - &scaled_aligned_causal_masked_softmax_forward, - "Scaled Bottom-Right Corner Aligned Masked Softmax FWD"); + &scaled_aligned_causal_masked_softmax_forward, + "Scaled Bottom-Right Corner Aligned Masked Softmax FWD"); m.def("scaled_aligned_causal_masked_softmax_backward", - &scaled_aligned_causal_masked_softmax_backward, - "Scaled Bottom-Right Corner Aligned Masked Softmax BWD"); + &scaled_aligned_causal_masked_softmax_backward, + "Scaled Bottom-Right Corner Aligned Masked Softmax BWD"); // Other granular functions - m.def("layernorm_fwd_fp8", - &layernorm_fwd_fp8, - "LN FWD FP8", - py::arg("input"), - py::arg("weight"), - py::arg("bias"), - py::arg("eps"), - py::arg("scale"), - py::arg("amax"), - py::arg("scale_inv"), - py::arg("otype"), - py::arg("sm_margin"), - py::arg("zero_centered_gamma"), - py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, - py::arg("scale_inv_offset") = 0); - m.def("layernorm_fwd_fp8_noalloc", - &layernorm_fwd_fp8_noalloc, - "LN FWD FP8", - py::arg("input"), - py::arg("weight"), - py::arg("bias"), - py::arg("eps"), - py::arg("scale"), - py::arg("ln_out"), - py::arg("amax"), - py::arg("scale_inv"), - py::arg("otype"), - py::arg("sm_margin"), - py::arg("zero_centered_gamma"), - py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, + m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8", py::arg("input"), py::arg("weight"), + py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), + py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"), + py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8", py::arg("input"), + py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"), + 
py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("sm_margin"), + py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("layernorm_bwd", &layernorm_bwd, "LN BWD"); m.def("layernorm_fwd", &layernorm_fwd, "LN FWD"); m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD"); - m.def("rmsnorm_fwd_fp8", - &rmsnorm_fwd_fp8, - "RMSNorm FWD FP8", - py::arg("input"), - py::arg("weight"), - py::arg("eps"), - py::arg("scale"), - py::arg("amax"), - py::arg("scale_inv"), - py::arg("otype"), - py::arg("sm_margin"), - py::arg("zero_centered_gamma"), - py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, - py::arg("scale_inv_offset") = 0); - m.def("rmsnorm_fwd_fp8_noalloc", - &rmsnorm_fwd_fp8_noalloc, - "RMSNorm FWD FP8", - py::arg("input"), - py::arg("weight"), - py::arg("eps"), - py::arg("scale"), - py::arg("ln_out"), - py::arg("amax"), - py::arg("scale_inv"), - py::arg("otype"), - py::arg("sm_margin"), - py::arg("zero_centered_gamma"), - py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, + m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8", py::arg("input"), py::arg("weight"), + py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), + py::arg("sm_margin"), py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, + py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8", py::arg("input"), + py::arg("weight"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"), + py::arg("scale_inv"), py::arg("otype"), py::arg("sm_margin"), + py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD"); m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD"); m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD"); m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); - m.def("fused_cast_transpose_noop", - &fused_cast_transpose_noop, - "Fused Cast + Transpose with noop option", - py::arg("input"), - py::arg("noop"), - py::arg("scale"), - py::arg("amax"), - py::arg("scale_inv"), - py::arg("input_cast"), - py::arg("input_transpose"), - py::arg("otype"), - py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, - py::arg("scale_inv_offset") = 0); - m.def("fused_cast_transpose_bgrad", - &fused_cast_transpose_bgrad, - "Fused Cast + Transpose + BGRAD", - py::arg("grad_output"), - py::arg("scale"), - py::arg("amax"), - py::arg("scale_inv"), - py::arg("otype"), - py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, - py::arg("scale_inv_offset") = 0); - m.def("fused_fp8_transpose_bgrad", - &fused_fp8_transpose_bgrad, - "Fused FP8 Transpose + BGRAD", - py::arg("grad_output"), - py::arg("scale"), - py::arg("amax"), - py::arg("scale_inv"), - py::arg("otype"), - py::arg("grad_bias_type"), - py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, - py::arg("scale_inv_offset") = 0); - m.def("fused_cast_transpose_bgrad_dgelu", - &fused_cast_transpose_bgrad_dgelu, - "Fused Cast + Transpose + BGRAD + DGELU", - py::arg("grad_output"), - py::arg("gelu_input"), - py::arg("scale"), - py::arg("amax"), - py::arg("scale_inv"), - py::arg("otype"), - py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, + m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop, + "Fused Cast + Transpose with noop option", py::arg("input"), 
py::arg("noop"), + py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("input_cast"), + py::arg("input_transpose"), py::arg("otype"), py::arg("scale_offset") = 0, + py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD", + py::arg("grad_output"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), + py::arg("otype"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, "Fused FP8 Transpose + BGRAD", + py::arg("grad_output"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), + py::arg("otype"), py::arg("grad_bias_type"), py::arg("scale_offset") = 0, + py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu, + "Fused Cast + Transpose + BGRAD + DGELU", py::arg("grad_output"), py::arg("gelu_input"), + py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), + py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, - "Fused Multi-tensor Cast + Transpose"); + "Fused Multi-tensor Cast + Transpose"); m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8"); m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8"); m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8"); m.def("te_gemm", &te_gemm, "CublasLt GEMM"); m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, - "Fused Attention FP8/BF16/FP16 FWD with packed QKV"); + "Fused Attention FP8/BF16/FP16 FWD with packed QKV"); m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked, - "Fused Attention FP8/BF16/FP16 BWD with packed QKV"); + "Fused Attention FP8/BF16/FP16 BWD with packed QKV"); m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked, - "Fused Attention FP8/BF16/FP16 FWD with packed KV"); + "Fused Attention FP8/BF16/FP16 FWD with packed KV"); m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked, - "Fused Attention FP8/BF16/FP16 BWD with packed KV"); + "Fused Attention FP8/BF16/FP16 BWD with packed KV"); m.def("fused_attn_fwd", &fused_attn_fwd, - "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); + "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); m.def("fused_attn_bwd", &fused_attn_bwd, - "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); + "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O"); m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop, - "Transpose with FP8 I/O with noop option."); + "Transpose with FP8 I/O with noop option."); m.def("gelu", &gelu, "GeLU with FP8 output"); m.def("relu", &relu, "ReLU with FP8 output"); m.def("geglu", &geglu, "GeGLU with FP8 output"); @@ -190,8 +110,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention"); m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention"); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); - m.def("fused_amax_and_scale_update_after_reduction", - &fused_amax_and_scale_update_after_reduction, + m.def("fused_amax_and_scale_update_after_reduction", 
&fused_amax_and_scale_update_after_reduction, "Update amax history and FP8 scale/scale_inv after reduction"); // fused apply rope @@ -240,78 +159,78 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Data structures py::class_(m, "FP8TensorMeta") - .def(py::init<>()) - .def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale) - .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) - .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); + .def(py::init<>()) + .def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale) + .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) + .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); // comm+GEMM overlap w/ userbuffers m.def("set_ubuf_bootstrap_callbacks", &ubuf::set_ubuf_bootstrap_callbacks); py::enum_(m, "UbufOverlapAlgo") - .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) - .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) - .value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS) - .value("SPLIT_PIPELINED_RS_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS_P2P) - .value("SPLIT_PIPELINED_AG_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG_P2P) - .value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS) - .value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P) - .value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P); + .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) + .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) + .value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS) + .value("SPLIT_PIPELINED_RS_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS_P2P) + .value("SPLIT_PIPELINED_AG_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG_P2P) + .value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS) + .value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P) + .value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P); py::class_(m, "UbufCommOverlap") - .def(py::init()) - .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap) - .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs) - .def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv) - .def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs) - .def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf) - .def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf) - .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output) - .def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm) - .def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap); + .def(py::init()) + .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap) + .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs) + .def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv) + .def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs) + .def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf) + .def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf) + .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output) + .def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm) + .def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap); py::class_(m, "UbufP2PCommOverlap") - .def(py::init()) - .def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag) - .def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs) - 
.def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag) - .def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs) - .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf) - .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output) - .def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf) - .def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm) - .def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap) - .def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv); + .def(py::init()) + .def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag) + .def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs) + .def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag) + .def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs) + .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf) + .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output) + .def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf) + .def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm) + .def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap) + .def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv); py::enum_(m, "DType", py::module_local()) - .value("kByte", transformer_engine::DType::kByte) - .value("kInt32", transformer_engine::DType::kInt32) - .value("kFloat32", transformer_engine::DType::kFloat32) - .value("kFloat16", transformer_engine::DType::kFloat16) - .value("kBFloat16", transformer_engine::DType::kBFloat16) - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); + .value("kByte", transformer_engine::DType::kByte) + .value("kInt32", transformer_engine::DType::kInt32) + .value("kFloat32", transformer_engine::DType::kFloat32) + .value("kFloat16", transformer_engine::DType::kFloat16) + .value("kBFloat16", transformer_engine::DType::kBFloat16) + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); py::enum_(m, "FP8FwdTensors") - .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) - .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) - .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT) - .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT) - .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT) - .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT) - .value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT) - .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT) - .value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT); + .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) + .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) + .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT) + .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT) + .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT) + .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT) + .value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT) + .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT) + .value("GEMM3_OUTPUT", 
transformer_engine::FP8FwdTensors::GEMM3_OUTPUT); py::enum_(m, "FP8BwdTensors") - .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1) - .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1) - .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2) - .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2) - .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) - .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); + .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1) + .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1) + .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2) + .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2) + .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) + .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); py::enum_(m, "NVTE_Bias_Type") .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cu b/transformer_engine/pytorch/csrc/extensions/recipe.cu index d5d8e2f7c854b2ceec440931bfb8e65690ca6535..a130169fe76bedad80babba1b1800c1281e0cde1 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cu +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cu @@ -4,21 +4,17 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" - -#include - #include #include +#include + +#include "extensions.h" -void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, - std::vector amax_histories, - std::vector scales, - std::vector scale_invs, - const std::string &amax_compute_algo, - transformer_engine::DType fp8_dtype, - float margin) { +void fused_amax_and_scale_update_after_reduction( + const at::Tensor &amax_reduction_buffer, std::vector amax_histories, + std::vector scales, std::vector scale_invs, + const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin) { using namespace transformer_engine; size_t num_tensors = amax_histories.size(); std::vector t_amax_histories(num_tensors); @@ -51,12 +47,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio te_scale_invs[i] = reinterpret_cast(&t_scale_invs[i]); } nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( - makeTransformerEngineTensor(amax_reduction_buffer).data(), - te_amax_histories, - te_scales, - te_scale_invs, - amax_compute_algo.c_str(), - static_cast(fp8_dtype), - margin, - at::cuda::getCurrentCUDAStream()); + makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales, + te_scale_invs, amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, + at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/softmax.cu b/transformer_engine/pytorch/csrc/extensions/softmax.cu index 6bae5f6b4669784a0533e09c06506f24222de4b2..acb68543d85b2c0b94cdeacec29ce062d3f841a6 100644 --- a/transformer_engine/pytorch/csrc/extensions/softmax.cu +++ b/transformer_engine/pytorch/csrc/extensions/softmax.cu @@ -6,25 +6,23 @@ #include "extensions.h" -at::Tensor scaled_softmax_forward(at::Tensor input, - float scale_factor -) { - using namespace transformer_engine; - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == 
at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - const int batches = input.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - - AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); - AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); - AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1"); - - // Output +at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { + using namespace transformer_engine; + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + const int batches = input.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + + AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); + AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); + AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1"); + + // Output auto act_options = input.options().requires_grad(false); auto softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); @@ -38,249 +36,212 @@ at::Tensor scaled_softmax_forward(at::Tensor input, return softmax_results; } +at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, + float scale_factor) { + using namespace transformer_engine; -at::Tensor scaled_softmax_backward(at::Tensor output_grad_, - at::Tensor softmax_results_, - float scale_factor -) { - using namespace transformer_engine; + auto output_grads = output_grad_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); - auto output_grads = output_grad_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); + AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor"); - AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - auto output_grads_cu = makeTransformerEngineTensor(output_grads); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - // Produce gradients in place. - nvte_scaled_softmax_backward( - output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), - scale_factor, at::cuda::getCurrentCUDAStream()); + // Produce gradients in place. 
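+  // output_grads is passed as both the incoming gradient and the destination
+  // buffer, so the caller's tensor is overwritten rather than allocating a new one.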
+ nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), + output_grads_cu.data(), scale_factor, + at::cuda::getCurrentCUDAStream()); - return output_grads; + return output_grads; } +at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) { + using namespace transformer_engine; + + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + if (!input.is_contiguous()) input = input.contiguous(); + if (!mask.is_contiguous()) mask = mask.contiguous(); + + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + + AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); + AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); + AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1"); + TORCH_CHECK(pad_batches == 1 || pad_batches == batches); + TORCH_CHECK(mask.size(1) == 1); + TORCH_CHECK(mask.size(2) == query_seq_len); + TORCH_CHECK(mask.size(3) == key_seq_len); -at::Tensor scaled_masked_softmax_forward(at::Tensor input, - at::Tensor mask, - float scale_factor -) { - using namespace transformer_engine; - - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); - if (!input.is_contiguous()) - input = input.contiguous(); - if (!mask.is_contiguous()) - mask = mask.contiguous(); - - const int batches = input.size(0); - const int pad_batches = mask.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - - AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); - AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); - AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1"); - TORCH_CHECK(pad_batches == 1 || pad_batches == batches); - TORCH_CHECK(mask.size(1) == 1); - TORCH_CHECK(mask.size(2) == query_seq_len); - TORCH_CHECK(mask.size(3) == key_seq_len); - - auto act_options = input.options().requires_grad(false); - auto softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - - - auto input_cu = makeTransformerEngineTensor(input); - auto mask_cu = makeTransformerEngineTensor(mask); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - - nvte_scaled_masked_softmax_forward( - input_cu.data(), mask_cu.data(), softmax_results_cu.data(), - scale_factor, at::cuda::getCurrentCUDAStream()); - - return softmax_results; -} + auto act_options = input.options().requires_grad(false); + auto softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + auto input_cu = makeTransformerEngineTensor(input); + auto mask_cu = makeTransformerEngineTensor(mask); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(), softmax_results_cu.data(), + scale_factor, at::cuda::getCurrentCUDAStream()); -at::Tensor 
scaled_masked_softmax_backward(at::Tensor output_grad_, - at::Tensor softmax_results_, - float scale_factor -) { - using namespace transformer_engine; + return softmax_results; +} - auto output_grads = output_grad_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); +at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, + float scale_factor) { + using namespace transformer_engine; - AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + auto output_grads = output_grad_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); - auto output_grads_cu = makeTransformerEngineTensor(output_grads); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); - // Produce gradients in place. - nvte_scaled_softmax_backward( - output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), - scale_factor, at::cuda::getCurrentCUDAStream()); + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - return output_grads; -} + // Produce gradients in place. 
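+  // The mask itself is not needed again here: the unmasked
+  // nvte_scaled_softmax_backward kernel is reused, since masked positions of
+  // softmax_results are (effectively) zero after the forward pass and therefore
+  // contribute no gradient.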
+ nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), + output_grads_cu.data(), scale_factor, + at::cuda::getCurrentCUDAStream()); + return output_grads; +} -at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, - float scale_factor -) { - using namespace transformer_engine; +at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) { + using namespace transformer_engine; - AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); - const int attn_batches = input.size(0); - const int seq_len = input.size(1); - AT_ASSERTM(seq_len <= 16384, "Sequence length must be 16384 or less"); + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + AT_ASSERTM(seq_len <= 16384, "Sequence length must be 16384 or less"); - // Output - auto act_options = input.options().requires_grad(false); - auto softmax_results = - torch::empty({attn_batches, seq_len, seq_len}, act_options); + // Output + auto act_options = input.options().requires_grad(false); + auto softmax_results = torch::empty({attn_batches, seq_len, seq_len}, act_options); - auto input_cu = makeTransformerEngineTensor(input); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + auto input_cu = makeTransformerEngineTensor(input); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), - softmax_results_cu.data(), - scale_factor, - at::cuda::getCurrentCUDAStream()); + nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(), + scale_factor, at::cuda::getCurrentCUDAStream()); - return softmax_results; + return softmax_results; } - at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, at::Tensor softmax_results_, - float scale_factor -) { - using namespace transformer_engine; + float scale_factor) { + using namespace transformer_engine; - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); - TORCH_CHECK(output_grads.size(1) == output_grads.size(2)); + 
TORCH_CHECK(output_grads.size(1) == output_grads.size(2)); - auto output_grads_cu = makeTransformerEngineTensor(output_grads); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - // Produce gradients in place. - nvte_scaled_upper_triang_masked_softmax_backward(output_grads_cu.data(), - softmax_results_cu.data(), - output_grads_cu.data(), - scale_factor, - at::cuda::getCurrentCUDAStream()); + // Produce gradients in place. + nvte_scaled_upper_triang_masked_softmax_backward( + output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor, + at::cuda::getCurrentCUDAStream()); return output_grads; } +at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) { + using namespace transformer_engine; + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); -at::Tensor scaled_aligned_causal_masked_softmax_forward( - at::Tensor input, - float scale_factor -) { - using namespace transformer_engine; - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - const int batches = input.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - - AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); - AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); - AT_ASSERTM(query_seq_len >= 1, "Query sequence length must be greater or equal to 1"); - - // Output - auto act_options = input.options().requires_grad(false); - auto softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - - auto input_cu = makeTransformerEngineTensor(input); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - - nvte_scaled_aligned_causal_masked_softmax_forward( - input_cu.data(), - softmax_results_cu.data(), - scale_factor, - at::cuda::getCurrentCUDAStream()); - - return softmax_results; -} + const int batches = input.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + + AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); + AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); + AT_ASSERTM(query_seq_len >= 1, "Query sequence length must be greater or equal to 1"); + + // Output + auto act_options = input.options().requires_grad(false); + auto softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + auto input_cu = makeTransformerEngineTensor(input); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); -at::Tensor scaled_aligned_causal_masked_softmax_backward( - at::Tensor output_grad_, - at::Tensor softmax_results_, - float scale_factor -) { - using namespace transformer_engine; + nvte_scaled_aligned_causal_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(), + scale_factor, at::cuda::getCurrentCUDAStream()); - auto output_grads = output_grad_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); 
+ return softmax_results; +} - AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor"); +at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_, + at::Tensor softmax_results_, + float scale_factor) { + using namespace transformer_engine; - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + auto output_grads = output_grad_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); - auto output_grads_cu = makeTransformerEngineTensor(output_grads); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor"); - // Produce gradients in place. - nvte_scaled_aligned_causal_masked_softmax_backward( - output_grads_cu.data(), - softmax_results_cu.data(), - output_grads_cu.data(), - scale_factor, - at::cuda::getCurrentCUDAStream()); + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); - return output_grads; + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + // Produce gradients in place. 
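+  // Gradients are again written back into output_grads in place; the
+  // bottom-right-aligned causal pattern is handled entirely inside the dedicated
+  // nvte_* backward kernel below, so no mask tensor is passed from this wrapper.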
+ nvte_scaled_aligned_causal_masked_softmax_backward( + output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor, + at::cuda::getCurrentCUDAStream()); + + return output_grads; } diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index 20215a65f07a052f5f4b800f70f30f52b10809b5..da4da2e190c0ca8fa75dce31fa9c9ec1bb575465 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -6,44 +6,30 @@ #include "extensions.h" -void fused_cast_transpose(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - at::Tensor input_cast, - at::Tensor input_transpose, - transformer_engine::DType otype -) { +void fused_cast_transpose(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + at::Tensor input_cast, at::Tensor input_transpose, + transformer_engine::DType otype) { using namespace transformer_engine; size_t M = static_cast(input.size(0)); size_t N = static_cast(input.size(1)); - auto input_cu = makeTransformerEngineTensor(input); - auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); + auto input_cu = makeTransformerEngineTensor(input); + auto output_cast_cu = + makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); + auto output_transpose_cu = + makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), at::cuda::getCurrentCUDAStream()); } - -void fused_cast_transpose_noop(at::Tensor input, - at::Tensor noop, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - at::Tensor input_cast, - at::Tensor input_transpose, - transformer_engine::DType otype, - int scale_offset, - int amax_offset, - int scale_inv_offset -) { +void fused_cast_transpose_noop(at::Tensor input, at::Tensor noop, at::Tensor scale, at::Tensor amax, + at::Tensor scale_inv, at::Tensor input_cast, + at::Tensor input_transpose, transformer_engine::DType otype, + int scale_offset, int amax_offset, int scale_inv_offset) { using namespace transformer_engine; // Tensor dimensions @@ -56,39 +42,23 @@ void fused_cast_transpose_noop(at::Tensor input, void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); // Construct Transformer Engine tensors - auto input_cu = makeTransformerEngineTensor(input); - auto noop_cu = makeTransformerEngineTensor(noop); - auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), - {M, N}, - otype, - amax_dptr, - scale_dptr, - scale_inv_dptr); - auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), - {N, M}, - otype, - amax_dptr, - scale_dptr, - scale_inv_dptr); + auto input_cu = makeTransformerEngineTensor(input); + auto noop_cu = makeTransformerEngineTensor(noop); + auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, amax_dptr, + scale_dptr, scale_inv_dptr); + auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, + amax_dptr, scale_dptr, scale_inv_dptr); // Launch kernel 
- nvte_cast_transpose_with_noop(input_cu.data(), - noop_cu.data(), - output_cast_cu.data(), - output_transpose_cu.data(), - at::cuda::getCurrentCUDAStream()); + nvte_cast_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cast_cu.data(), + output_transpose_cu.data(), at::cuda::getCurrentCUDAStream()); } - -std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, +std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - int scale_offset, - int amax_offset, - int scale_inv_offset -) { + int scale_offset, int amax_offset, + int scale_inv_offset) { using namespace transformer_engine; // Tensor dimensions @@ -98,12 +68,10 @@ std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, // Allocate output tensors DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); - auto grad_output_cast = allocateTorchTensor(grad_output.size(0), - grad_output.size(1), - DType::kByte); - auto grad_output_transpose = allocateTorchTensor(grad_output.size(1), - grad_output.size(0), - DType::kByte); + auto grad_output_cast = + allocateTorchTensor(grad_output.size(0), grad_output.size(1), DType::kByte); + auto grad_output_transpose = + allocateTorchTensor(grad_output.size(1), grad_output.size(0), DType::kByte); // Return immediately if tensors are empty if (M == 0 || N == 0) { @@ -116,50 +84,34 @@ std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); // Construct Transformer Engine tensors - auto input_cu = makeTransformerEngineTensor(grad_output); - auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), - {M, N}, - otype, - amax_dptr, - scale_dptr, - scale_inv_dptr); - auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(), - {N, M}, - otype, - amax_dptr, - scale_dptr, - scale_inv_dptr); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); + auto input_cu = makeTransformerEngineTensor(grad_output); + auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N}, otype, + amax_dptr, scale_dptr, scale_inv_dptr); + auto transposed_output_cu = makeTransformerEngineTensor( + grad_output_transpose.data_ptr(), {N, M}, otype, amax_dptr, scale_dptr, scale_inv_dptr); + auto dbias_cu = makeTransformerEngineTensor(grad_bias); // Query workspace size and allocate workspace transformer_engine::TensorWrapper workspace; - nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), transposed_output_cu.data(), + dbias_cu.data(), workspace.data(), at::cuda::getCurrentCUDAStream()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // Launch kernel - nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), 
transposed_output_cu.data(), + dbias_cu.data(), workspace.data(), at::cuda::getCurrentCUDAStream()); return {grad_bias, grad_output_cast, grad_output_transpose}; } - -std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, +std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, transformer_engine::DType grad_bias_type, - int scale_offset, - int amax_offset, - int scale_inv_offset -) { + int scale_offset, int amax_offset, + int scale_inv_offset) { using namespace transformer_engine; // Tensor dimensions @@ -173,21 +125,12 @@ std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, // Construct Transformer Engine tensors auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type); - auto grad_output_transpose = allocateTorchTensor(grad_output.size(1), - grad_output.size(0), - DType::kByte); - auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), - {M, N}, - otype, - amax_dptr, - scale_dptr, - scale_inv_dptr); - auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(), - {N, M}, - otype, - amax_dptr, - scale_dptr, - scale_inv_dptr); + auto grad_output_transpose = + allocateTorchTensor(grad_output.size(1), grad_output.size(0), DType::kByte); + auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), {M, N}, otype, amax_dptr, + scale_dptr, scale_inv_dptr); + auto transposed_output_cu = makeTransformerEngineTensor( + grad_output_transpose.data_ptr(), {N, M}, otype, amax_dptr, scale_dptr, scale_inv_dptr); auto dbias_cu = makeTransformerEngineTensor(grad_bias); // Query workspace size and allocate workspace @@ -195,9 +138,8 @@ std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), workspace.data(), at::cuda::getCurrentCUDAStream()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // Launch kernel nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), @@ -206,18 +148,12 @@ std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, return {grad_bias, grad_output_transpose}; } - - std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, - at::Tensor gelu_input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, + at::Tensor gelu_input, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - int scale_offset, - int amax_offset, - int scale_inv_offset -) { + int scale_offset, int amax_offset, + int scale_inv_offset) { using namespace transformer_engine; // Tensor dimensions @@ -232,71 +168,51 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, // Construct Transformer Engine tensors DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); - auto dgelu = allocateTorchTensor(grad_output.size(0), - grad_output.size(1), - DType::kByte); - auto dgelu_transpose = allocateTorchTensor(grad_output.size(1), - grad_output.size(0), - DType::kByte); - auto gelu_input_cu = makeTransformerEngineTensor(gelu_input); - auto 
input_cu = makeTransformerEngineTensor(grad_output); - auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), - {M, N}, - otype, - amax_dptr, - scale_dptr, - scale_inv_dptr); - auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), - {N, M}, - otype, - amax_dptr, - scale_dptr, - scale_inv_dptr); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); + auto dgelu = allocateTorchTensor(grad_output.size(0), grad_output.size(1), DType::kByte); + auto dgelu_transpose = + allocateTorchTensor(grad_output.size(1), grad_output.size(0), DType::kByte); + auto gelu_input_cu = makeTransformerEngineTensor(gelu_input); + auto input_cu = makeTransformerEngineTensor(grad_output); + auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N}, otype, amax_dptr, + scale_dptr, scale_inv_dptr); + auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M}, otype, + amax_dptr, scale_dptr, scale_inv_dptr); + auto dbias_cu = makeTransformerEngineTensor(grad_bias); // Query workspace size and allocate workspace transformer_engine::TensorWrapper workspace; - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), - cast_output_cu.data(), transposed_output_cu.data(), - dbias_cu.data(), workspace.data(), + nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), + transposed_output_cu.data(), dbias_cu.data(), workspace.data(), at::cuda::getCurrentCUDAStream()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // Launch kernel - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), - cast_output_cu.data(), transposed_output_cu.data(), - dbias_cu.data(), workspace.data(), + nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), + transposed_output_cu.data(), dbias_cu.data(), workspace.data(), at::cuda::getCurrentCUDAStream()); return {grad_bias, dgelu, dgelu_transpose}; } - void fused_multi_cast_transpose(std::vector input_list, std::vector scale_list, std::vector cast_output_list, std::vector transposed_output_list, std::vector amax_list, std::vector scale_inv_list, - transformer_engine::DType otype -) { + transformer_engine::DType otype) { using namespace transformer_engine; // Extract properties from PyTorch tensors - std::vector input_dptr_list, scale_dptr_list, - cast_output_dptr_list, transposed_output_dptr_list, - amax_dptr_list, scale_inv_dptr_list; - std::vector> input_shape_list, scale_shape_list, - cast_output_shape_list, transposed_output_shape_list, - amax_shape_list, scale_inv_shape_list; - std::vector input_type_list, scale_type_list, - cast_output_type_list, transposed_output_type_list, - amax_type_list, scale_inv_type_list; - auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor, - std::vector& dptr_list, + std::vector input_dptr_list, scale_dptr_list, cast_output_dptr_list, + transposed_output_dptr_list, amax_dptr_list, scale_inv_dptr_list; + std::vector> input_shape_list, scale_shape_list, cast_output_shape_list, + transposed_output_shape_list, amax_shape_list, scale_inv_shape_list; + std::vector input_type_list, scale_type_list, cast_output_type_list, + transposed_output_type_list, amax_type_list, scale_inv_type_list; + auto 
extract_tensor_props_skip_dtype = [](at::Tensor& tensor, std::vector& dptr_list, std::vector>& shape_list) { dptr_list.push_back(tensor.data_ptr()); shape_list.push_back({}); @@ -304,8 +220,7 @@ void fused_multi_cast_transpose(std::vector input_list, shape_list.back().push_back(tensor.size(d)); } }; - auto extract_tensor_props = [](at::Tensor& tensor, - std::vector& dptr_list, + auto extract_tensor_props = [](at::Tensor& tensor, std::vector& dptr_list, std::vector>& shape_list, std::vector& type_list) { dptr_list.push_back(tensor.data_ptr()); @@ -316,68 +231,41 @@ void fused_multi_cast_transpose(std::vector input_list, type_list.push_back(GetTransformerEngineDType(tensor.scalar_type())); }; for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { - extract_tensor_props(input_list[tensor_id], - input_dptr_list, - input_shape_list, - input_type_list); - extract_tensor_props(scale_list[tensor_id], - scale_dptr_list, - scale_shape_list, - scale_type_list); - extract_tensor_props_skip_dtype(cast_output_list[tensor_id], - cast_output_dptr_list, + extract_tensor_props(input_list[tensor_id], input_dptr_list, input_shape_list, input_type_list); + extract_tensor_props(scale_list[tensor_id], scale_dptr_list, scale_shape_list, scale_type_list); + extract_tensor_props_skip_dtype(cast_output_list[tensor_id], cast_output_dptr_list, cast_output_shape_list); cast_output_type_list.push_back(otype); - extract_tensor_props_skip_dtype(transposed_output_list[tensor_id], - transposed_output_dptr_list, + extract_tensor_props_skip_dtype(transposed_output_list[tensor_id], transposed_output_dptr_list, transposed_output_shape_list); transposed_output_type_list.push_back(otype); - extract_tensor_props(amax_list[tensor_id], - amax_dptr_list, - amax_shape_list, - amax_type_list); - extract_tensor_props(scale_inv_list[tensor_id], - scale_inv_dptr_list, - scale_inv_shape_list, + extract_tensor_props(amax_list[tensor_id], amax_dptr_list, amax_shape_list, amax_type_list); + extract_tensor_props(scale_inv_list[tensor_id], scale_inv_dptr_list, scale_inv_shape_list, scale_inv_type_list); } transformer_engine::TensorWrapper workspace; // Construct TE tensors - std::vector nvte_input_list, - nvte_cast_output_list, nvte_transposed_output_list; + std::vector nvte_input_list, nvte_cast_output_list, nvte_transposed_output_list; std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, - const std::vector& shape, - transformer_engine::DType dtype, - void* amax_dptr, - void* scale_dptr, - void* scale_inv_dptr) - -> NVTETensor { - tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, - scale_dptr, scale_inv_dptr)); + auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + transformer_engine::DType dtype, void* amax_dptr, + void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { + tensor_wrappers.emplace_back( + makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); return tensor_wrappers.back().data(); }; for (size_t i = 0; i < input_dptr_list.size(); ++i) { - nvte_input_list.emplace_back(make_tensor(input_dptr_list[i], - input_shape_list[i], - input_type_list[i], - nullptr, - nullptr, - nullptr)); - nvte_cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i], - cast_output_shape_list[i], - cast_output_type_list[i], - amax_dptr_list[i], - scale_dptr_list[i], - scale_inv_dptr_list[i])); - nvte_transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i], - 
transposed_output_shape_list[i],
-                                                         transposed_output_type_list[i],
-                                                         amax_dptr_list[i],
-                                                         scale_dptr_list[i],
-                                                         scale_inv_dptr_list[i]));
+    nvte_input_list.emplace_back(make_tensor(input_dptr_list[i], input_shape_list[i],
+                                             input_type_list[i], nullptr, nullptr, nullptr));
+    nvte_cast_output_list.emplace_back(
+        make_tensor(cast_output_dptr_list[i], cast_output_shape_list[i], cast_output_type_list[i],
+                    amax_dptr_list[i], scale_dptr_list[i], scale_inv_dptr_list[i]));
+    nvte_transposed_output_list.emplace_back(
+        make_tensor(transposed_output_dptr_list[i], transposed_output_shape_list[i],
+                    transposed_output_type_list[i], amax_dptr_list[i], scale_dptr_list[i],
+                    scale_inv_dptr_list[i]));
   }
 
   // Check tensor lists
@@ -387,30 +275,21 @@ void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
              "Number of input and T output tensors must match");
 
   // Launch TE kernel
-  nvte_multi_cast_transpose(nvte_input_list.size(),
-                            nvte_input_list.data(),
-                            nvte_cast_output_list.data(),
-                            nvte_transposed_output_list.data(),
+  nvte_multi_cast_transpose(nvte_input_list.size(), nvte_input_list.data(),
+                            nvte_cast_output_list.data(), nvte_transposed_output_list.data(),
                             at::cuda::getCurrentCUDAStream());
 }
 
-
-at::Tensor fp8_transpose(at::Tensor input,
-                         transformer_engine::DType otype
-) {
+at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype) {
   using namespace transformer_engine;
 
   size_t M = static_cast<size_t>(input.size(0));
   size_t N = static_cast<size_t>(input.size(1));
-  if (M == 0 || N == 0)
-    return input;
+  if (M == 0 || N == 0) return input;
 
-  auto output =
-      allocateTorchTensor(input.size(1),
-                          input.size(0),
-                          DType::kByte);
+  auto output = allocateTorchTensor(input.size(1), input.size(0), DType::kByte);
 
-  auto input_cu  = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
   auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
 
   nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
@@ -418,38 +297,29 @@ at::Tensor fp8_transpose(at::Tensor input,
   return output;
 }
 
-
-void fp8_transpose_noalloc(at::Tensor input,
-                           at::Tensor output,
-                           transformer_engine::DType otype
-) {
+void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype) {
   using namespace transformer_engine;
 
   size_t M = static_cast<size_t>(input.size(0));
   size_t N = static_cast<size_t>(input.size(1));
 
-  auto input_cu  = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
   auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
 
   nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
 }
 
-
-void fp8_transpose_noalloc_noop(at::Tensor input,
-                                at::Tensor output,
-                                at::Tensor noop,
-                                transformer_engine::DType otype
-) {
+void fp8_transpose_noalloc_noop(at::Tensor input, at::Tensor output, at::Tensor noop,
+                                transformer_engine::DType otype) {
   using namespace transformer_engine;
 
   size_t M = static_cast<size_t>(input.size(0));
   size_t N = static_cast<size_t>(input.size(1));
 
-  auto input_cu  = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
-  auto noop_cu   = makeTransformerEngineTensor(noop);
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
+  auto noop_cu = makeTransformerEngineTensor(noop);
   auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
 
-  nvte_transpose_with_noop(
-      input_cu.data(), noop_cu.data(),
output_cu.data(), - at::cuda::getCurrentCUDAStream()); + nvte_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp old mode 100755 new mode 100644 index bf24e81ee44a2dec345d5b4e93d2ee3bd77ed2ac..d38bcb2829bc1845db3c49873e641b7299aa651f --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -4,343 +4,244 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include "extensions.h" #include #include -#include "common/util/system.h" +#include + #include "common/util/cuda_runtime.h" +#include "common/util/system.h" +#include "extensions.h" namespace { - transformer_engine::DType reverse_map_dtype(int64_t dtype) { - if (dtype >= 0 && dtype < static_cast(transformer_engine::DType::kNumTypes)) { - return static_cast(dtype); - } else { - NVTE_ERROR("Type not supported."); - } +transformer_engine::DType reverse_map_dtype(int64_t dtype) { + if (dtype >= 0 && dtype < static_cast(transformer_engine::DType::kNumTypes)) { + return static_cast(dtype); + } else { + NVTE_ERROR("Type not supported."); } +} } // namespace - -at::Tensor cast_to_fp8_ts(const at::Tensor &input, - const at::Tensor &scale, - at::Tensor amax, - at::Tensor scale_inv, - int64_t fp8_tensor, - int64_t otype) { +at::Tensor cast_to_fp8_ts(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, + at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); - at::Tensor output = cast_to_fp8(input, - scale[fp8_tensor], - amax[0][fp8_tensor], - scale_inv[fp8_tensor], - otype_arg); + at::Tensor output = + cast_to_fp8(input, scale[fp8_tensor], amax[0][fp8_tensor], scale_inv[fp8_tensor], otype_arg); return output; } - -at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input, - const at::Tensor &scale, - at::Tensor output, - at::Tensor amax, - at::Tensor scale_inv, - int64_t fp8_tensor, - int64_t otype) { +at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input, const at::Tensor &scale, + at::Tensor output, at::Tensor amax, at::Tensor scale_inv, + int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); - cast_to_fp8_noalloc(input, - scale[fp8_tensor], - output, - amax[0][fp8_tensor], - scale_inv[fp8_tensor], + cast_to_fp8_noalloc(input, scale[fp8_tensor], output, amax[0][fp8_tensor], scale_inv[fp8_tensor], otype_arg); return output; } - -at::Tensor cast_from_fp8_ts(const at::Tensor &input, - const at::Tensor &scale_inv, - int64_t fp8_tensor, - int64_t itype, - int64_t otype) { +at::Tensor cast_from_fp8_ts(const at::Tensor &input, const at::Tensor &scale_inv, + int64_t fp8_tensor, int64_t itype, int64_t otype) { transformer_engine::DType itype_arg = reverse_map_dtype(itype); transformer_engine::DType otype_arg = reverse_map_dtype(otype); - at::Tensor output = cast_from_fp8(input, - scale_inv[fp8_tensor], - itype_arg, - otype_arg); + at::Tensor output = cast_from_fp8(input, scale_inv[fp8_tensor], itype_arg, otype_arg); return output; } - -at::Tensor gelu_ts(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - int64_t fp8_tensor, - int64_t otype) { +at::Tensor gelu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg 
= reverse_map_dtype(otype); at::Tensor s, a, s_inv; if (scale.numel()) { - s = scale[fp8_tensor]; + s = scale[fp8_tensor]; } else { - s = scale; + s = scale; } if (amax.numel()) { - a = amax[0][fp8_tensor]; + a = amax[0][fp8_tensor]; } else { - a = amax; + a = amax; } if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; + s_inv = scale_inv[fp8_tensor]; } else { - s_inv = scale_inv; + s_inv = scale_inv; } - at::Tensor output = gelu(input, - s, - a, - s_inv, - otype_arg); + at::Tensor output = gelu(input, s, a, s_inv, otype_arg); return output; } - -at::Tensor relu_ts(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - int64_t fp8_tensor, - int64_t otype) { +at::Tensor relu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); at::Tensor s, a, s_inv; if (scale.numel()) { - s = scale[fp8_tensor]; + s = scale[fp8_tensor]; } else { - s = scale; + s = scale; } if (amax.numel()) { - a = amax[0][fp8_tensor]; + a = amax[0][fp8_tensor]; } else { - a = amax; + a = amax; } if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; + s_inv = scale_inv[fp8_tensor]; } else { - s_inv = scale_inv; + s_inv = scale_inv; } - at::Tensor output = relu(input, - s, - a, - s_inv, - otype_arg); + at::Tensor output = relu(input, s, a, s_inv, otype_arg); return output; } - -at::Tensor reglu_ts(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - int64_t fp8_tensor, - int64_t otype) { +at::Tensor reglu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); at::Tensor s, a, s_inv; if (scale.numel()) { - s = scale[fp8_tensor]; + s = scale[fp8_tensor]; } else { - s = scale; + s = scale; } if (amax.numel()) { - a = amax[0][fp8_tensor]; + a = amax[0][fp8_tensor]; } else { - a = amax; + a = amax; } if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; + s_inv = scale_inv[fp8_tensor]; } else { - s_inv = scale_inv; + s_inv = scale_inv; } - at::Tensor output = reglu(input, - s, - a, - s_inv, - otype_arg); + at::Tensor output = reglu(input, s, a, s_inv, otype_arg); return output; } - -at::Tensor geglu_ts(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - int64_t fp8_tensor, - int64_t otype) { +at::Tensor geglu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); at::Tensor s, a, s_inv; if (scale.numel()) { - s = scale[fp8_tensor]; + s = scale[fp8_tensor]; } else { - s = scale; + s = scale; } if (amax.numel()) { - a = amax[0][fp8_tensor]; + a = amax[0][fp8_tensor]; } else { - a = amax; + a = amax; } if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; + s_inv = scale_inv[fp8_tensor]; } else { - s_inv = scale_inv; + s_inv = scale_inv; } - at::Tensor output = geglu(input, - s, - a, - s_inv, - otype_arg); + at::Tensor output = geglu(input, s, a, s_inv, otype_arg); return output; } - -at::Tensor swiglu_ts(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - int64_t fp8_tensor, - int64_t otype) { +at::Tensor swiglu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); at::Tensor s, a, s_inv; if 
(scale.numel()) { - s = scale[fp8_tensor]; + s = scale[fp8_tensor]; } else { - s = scale; + s = scale; } if (amax.numel()) { - a = amax[0][fp8_tensor]; + a = amax[0][fp8_tensor]; } else { - a = amax; + a = amax; } if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; + s_inv = scale_inv[fp8_tensor]; } else { - s_inv = scale_inv; + s_inv = scale_inv; } - at::Tensor output = swiglu(input, - s, - a, - s_inv, - otype_arg); + at::Tensor output = swiglu(input, s, a, s_inv, otype_arg); return output; } -at::Tensor qgelu_ts(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - int64_t fp8_tensor, - int64_t otype) { +at::Tensor qgelu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); at::Tensor s, a, s_inv; if (scale.numel()) { - s = scale[fp8_tensor]; + s = scale[fp8_tensor]; } else { - s = scale; + s = scale; } if (amax.numel()) { - a = amax[0][fp8_tensor]; + a = amax[0][fp8_tensor]; } else { - a = amax; + a = amax; } if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; + s_inv = scale_inv[fp8_tensor]; } else { - s_inv = scale_inv; + s_inv = scale_inv; } - at::Tensor output = qgelu(input, - s, - a, - s_inv, - otype_arg); + at::Tensor output = qgelu(input, s, a, s_inv, otype_arg); return output; } -at::Tensor srelu_ts(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - int64_t fp8_tensor, - int64_t otype) { +at::Tensor srelu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); at::Tensor s, a, s_inv; if (scale.numel()) { - s = scale[fp8_tensor]; + s = scale[fp8_tensor]; } else { - s = scale; + s = scale; } if (amax.numel()) { - a = amax[0][fp8_tensor]; + a = amax[0][fp8_tensor]; } else { - a = amax; + a = amax; } if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; + s_inv = scale_inv[fp8_tensor]; } else { - s_inv = scale_inv; + s_inv = scale_inv; } - at::Tensor output = srelu(input, - s, - a, - s_inv, - otype_arg); + at::Tensor output = srelu(input, s, a, s_inv, otype_arg); return output; } -at::Tensor te_gemm_ts(at::Tensor A, - at::Tensor A_scale_inverse, - int64_t A_fp8_tensor, - int64_t A_type, - int64_t transa, - at::Tensor B, - at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, - int64_t B_type, - int64_t transb, - at::Tensor D, - at::Tensor D_scale, - int64_t D_type, - at::Tensor D_amax, - at::Tensor bias, - int64_t bias_type, - at::Tensor pre_gelu_out, - int64_t grad, - at::Tensor workspace, - int64_t workspaceSize, - int64_t accumulate, +at::Tensor te_gemm_ts(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + int64_t A_type, int64_t transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, int64_t B_type, int64_t transb, at::Tensor D, + at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, at::Tensor bias, + int64_t bias_type, at::Tensor pre_gelu_out, int64_t grad, + at::Tensor workspace, int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator) { // cast inputs to types accepted by te_gemm transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); @@ -360,135 +261,69 @@ at::Tensor te_gemm_ts(at::Tensor A, const int sm_count = transformer_engine::cuda::sm_count(); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); - if (A_scale_inverse.numel()) - A_scale_inverse = 
A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) - B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - te_gemm(A, - A_scale_inverse, - A_type_arg, - transa_arg, - B, - B_scale_inverse, - B_type_arg, - transb_arg, - D, - D_scale, - D_type_arg, - D_amax, - bias, - bias_type_arg, - pre_gelu_out, - grad_arg, - workspace, - workspaceSize_arg, - accumulate_arg, - use_split_accumulator_arg, - num_math_sms); + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + + if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + + te_gemm(A, A_scale_inverse, A_type_arg, transa_arg, B, B_scale_inverse, B_type_arg, transb_arg, D, + D_scale, D_type_arg, D_amax, bias, bias_type_arg, pre_gelu_out, grad_arg, workspace, + workspaceSize_arg, accumulate_arg, use_split_accumulator_arg, num_math_sms); return D; } - -at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - double eps, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - int64_t fp8_tensor, - int64_t otype, - const int64_t sm_margin, +at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, double eps, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, + int64_t otype, const int64_t sm_margin, const bool zero_centered_gamma) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); float eps_float = static_cast(eps); - at::Tensor output = layernorm_fwd_fp8_inf(input, - weight, - bias, - eps_float, - scale, - amax, - scale_inv, - otype_arg, - sm_margin, - zero_centered_gamma, - fp8_tensor, // scale_offset - fp8_tensor, // amax_offset + at::Tensor output = layernorm_fwd_fp8_inf(input, weight, bias, eps_float, scale, amax, scale_inv, + otype_arg, sm_margin, zero_centered_gamma, + fp8_tensor, // scale_offset + fp8_tensor, // amax_offset fp8_tensor); // scale_inv_offset return output; } - -at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - double eps, - const int64_t sm_margin, +at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, double eps, const int64_t sm_margin, const bool zero_centered_gamma) { float eps_float = static_cast(eps); - at::Tensor output = layernorm_fwd_inf(input, - weight, - bias, - eps_float, - sm_margin, - zero_centered_gamma); + at::Tensor output = + layernorm_fwd_inf(input, weight, bias, eps_float, sm_margin, zero_centered_gamma); return output; } - -at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, - const at::Tensor &weight, - double eps, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - int64_t fp8_tensor, - int64_t otype, - const int64_t sm_margin, +at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, double eps, + at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + int64_t fp8_tensor, int64_t otype, const int64_t sm_margin, const bool zero_centered_gamma) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); float eps_float = static_cast(eps); - at::Tensor output = rmsnorm_fwd_fp8_inf(input, - weight, - eps_float, - scale, - amax, - scale_inv, - otype_arg, - sm_margin, - zero_centered_gamma, - fp8_tensor, // scale_offset - fp8_tensor, // amax_offset + at::Tensor output = rmsnorm_fwd_fp8_inf(input, weight, eps_float, scale, amax, scale_inv, + otype_arg, sm_margin, zero_centered_gamma, + fp8_tensor, // scale_offset + 
fp8_tensor, // amax_offset fp8_tensor); // scale_inv_offset return output; } - -at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input, - const at::Tensor &weight, - double eps, - const int64_t sm_margin, - const bool zero_centered_gamma) { +at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input, const at::Tensor &weight, double eps, + const int64_t sm_margin, const bool zero_centered_gamma) { float eps_float = static_cast(eps); - at::Tensor output = rmsnorm_fwd_inf(input, - weight, - eps_float, - sm_margin, - zero_centered_gamma); + at::Tensor output = rmsnorm_fwd_inf(input, weight, eps_float, sm_margin, zero_centered_gamma); return output; } - TORCH_LIBRARY(tex_ts, m) { m.def("cast_to_fp8_ts", &cast_to_fp8_ts); m.def("cast_to_fp8_noalloc_ts", &cast_to_fp8_noalloc_ts); diff --git a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc b/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc index 4972a80e9eaef986efc9430cdbf1a444e2fa22c3..c80709a7e7d24234cde7299e63184037e86e93db 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc +++ b/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc @@ -5,17 +5,18 @@ ************************************************************************/ #include "ipcsocket.h" + #include #include #include -#define WARN(...) \ +#define WARN(...) \ {} -#define TRACE(...) \ +#define TRACE(...) \ {} -#define SYSCHECK(...) \ +#define SYSCHECK(...) \ {} -#define EQCHECK(...) \ +#define EQCHECK(...) \ {} // Enable Linux abstract socket naming @@ -47,8 +48,7 @@ ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash, cliaddr.sun_family = AF_UNIX; // Create unique name for the socket. - size_t len = - snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash); + size_t len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash); if (len > (sizeof(cliaddr.sun_path) - 1)) { WARN("UDS: Cannot bind provided name to socket. 
Name too large");
     return ncclInternalError;
@@ -64,8 +64,7 @@ ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash,
   cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
 #endif
   if (bind(fd, (struct sockaddr *)&cliaddr, sizeof(cliaddr)) < 0) {
-    WARN("UDS: Binding to socket %s failed : %s (%d)", temp, strerror(errno),
-         errno);
+    WARN("UDS: Binding to socket %s failed : %s (%d)", temp, strerror(errno), errno);
     close(fd);
     return ncclSystemError;
   }
@@ -88,8 +87,7 @@ ncclResult_t ncclIpcSocketGetFd(struct ncclIpcSocket *handle, int *fd) {
     WARN("ncclSocketGetFd: pass NULL socket");
     return ncclInvalidArgument;
   }
-  if (fd)
-    *fd = handle->fd;
+  if (fd) *fd = handle->fd;
   return ncclSuccess;
 }
 
@@ -110,8 +108,7 @@ ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) {
   return ncclSuccess;
 }
 
-ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
-                                  int *recvFd) {
+ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, int *recvFd) {
   struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
   struct iovec iov[1];
 
@@ -144,15 +141,12 @@ ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
       WARN("UDS: Receiving data over socket failed : %d", errno);
       return ncclSystemError;
     }
-    if (handle->abortFlag && *handle->abortFlag)
-      return ncclInternalError;
+    if (handle->abortFlag && *handle->abortFlag) return ncclInternalError;
   }
 
   if (recvFd != NULL) {
-    if (((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) &&
-        (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) {
-      if ((cmptr->cmsg_level != SOL_SOCKET) ||
-          (cmptr->cmsg_type != SCM_RIGHTS)) {
+    if (((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) {
+      if ((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) {
         WARN("UDS: Receiving data over socket failed");
         return ncclSystemError;
       }
@@ -162,8 +156,7 @@ ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
       WARN("UDS: Receiving data over socket %s failed", handle->socketName);
       return ncclSystemError;
     }
-    TRACE(NCCL_INIT | NCCL_P2P, "UDS: Got recvFd %d from socket %s", *recvFd,
-          handle->socketName);
+    TRACE(NCCL_INIT | NCCL_P2P, "UDS: Got recvFd %d from socket %s", *recvFd, handle->socketName);
   }
 
   return ncclSuccess;
@@ -173,8 +166,8 @@ ncclResult_t ncclIpcSocketRecvFd(ncclIpcSocket *handle, int *recvFd) {
   return ncclIpcSocketRecvMsg(handle, NULL, 0, recvFd);
 }
 
-ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
-                                  const int sendFd, int rank, uint64_t hash) {
+ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, const int sendFd,
+                                  int rank, uint64_t hash) {
   struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
   struct iovec iov[1];
   char temp[NCCL_IPC_SOCKNAME_LEN];
@@ -192,8 +185,7 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
   bzero(&cliaddr, sizeof(cliaddr));
   cliaddr.sun_family = AF_UNIX;
 
-  size_t len =
-      snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash);
+  size_t len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash);
   if (len > (sizeof(cliaddr.sun_path) - 1)) {
     WARN("UDS: Cannot connect to provided name for socket. Name too large");
     return ncclInternalError;
@@ -204,8 +196,7 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
   cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
 #endif
 
-  TRACE(NCCL_INIT, "UDS: Sending hdr %p len %d to UDS socket %s", hdr, hdrLen,
-        temp);
+  TRACE(NCCL_INIT, "UDS: Sending hdr %p len %d to UDS socket %s", hdr, hdrLen, temp);
 
   if (sendFd != -1) {
     TRACE(NCCL_INIT, "UDS: Sending fd %d to UDS socket %s", sendFd, temp);
@@ -237,18 +228,15 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
   ssize_t sendResult;
   while ((sendResult = sendmsg(handle->fd, &msg, 0)) < 0) {
     if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
-      WARN("UDS: Sending data over socket %s failed : %s (%d)", temp,
-           strerror(errno), errno);
+      WARN("UDS: Sending data over socket %s failed : %s (%d)", temp, strerror(errno), errno);
       return ncclSystemError;
     }
-    if (handle->abortFlag && *handle->abortFlag)
-      return ncclInternalError;
+    if (handle->abortFlag && *handle->abortFlag) return ncclInternalError;
   }
 
   return ncclSuccess;
 }
 
-ncclResult_t ncclIpcSocketSendFd(ncclIpcSocket *handle, const int sendFd,
-                                 int rank, uint64_t hash) {
+ncclResult_t ncclIpcSocketSendFd(ncclIpcSocket *handle, const int sendFd, int rank, uint64_t hash) {
   return ncclIpcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash);
 }
diff --git a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h b/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h
index 31c9c9419a202344963b056860bfcc947d5783ae..cc1e45febfeb913ee1d03a8a8a6b8ade50e17c10 100644
--- a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h
+++ b/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h
@@ -40,13 +40,13 @@ struct ncclIpcSocket {
   volatile uint32_t *abortFlag;
 };
 
-ncclResult_t ncclIpcSocketInit(struct ncclIpcSocket *handle, int rank,
-                               uint64_t hash, volatile uint32_t *abortFlag);
+ncclResult_t ncclIpcSocketInit(struct ncclIpcSocket *handle, int rank, uint64_t hash,
+                               volatile uint32_t *abortFlag);
 ncclResult_t ncclIpcSocketClose(struct ncclIpcSocket *handle);
 ncclResult_t ncclIpcSocketGetFd(struct ncclIpcSocket *handle, int *fd);
 ncclResult_t ncclIpcSocketRecvFd(struct ncclIpcSocket *handle, int *fd);
-ncclResult_t ncclIpcSocketSendFd(struct ncclIpcSocket *handle, const int fd,
-                                 int rank, uint64_t hash);
+ncclResult_t ncclIpcSocketSendFd(struct ncclIpcSocket *handle, const int fd, int rank,
+                                 uint64_t hash);
 
 #endif /* NCCL_IPCSOCKET_H */
diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp
index e63bee121afce17fdeccef330ca5be965987f66d..bc93c61b3e0c7908b8d0951e03644416fd71faaa 100644
--- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp
+++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp
@@ -4,19 +4,21 @@
  * See LICENSE for license information.
************************************************************************/ -#include "ipcsocket.h" -#include "userbuffers.h" #include -#include #include #include -#include +#include #include #include #include #include #include -#include + +#include +#include + +#include "ipcsocket.h" +#include "userbuffers.h" #ifdef UB_MPI_BOOTSTRAP #include @@ -33,46 +35,46 @@ static char EXT_COMM_INTER[] = "inter"; int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); } -#define CUDACHECK(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ } while (0) -#define CUCHECK(cmd) \ - do { \ - CUresult retval = cmd; \ - if (retval != CUDA_SUCCESS) { \ - const char *error_string; \ - cuGetErrorString(retval, &error_string); \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \ - exit(EXIT_FAILURE); \ - } \ +#define CUCHECK(cmd) \ + do { \ + CUresult retval = cmd; \ + if (retval != CUDA_SUCCESS) { \ + const char *error_string; \ + cuGetErrorString(retval, &error_string); \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \ + exit(EXIT_FAILURE); \ + } \ } while (0); -#define NVTE_UB_ERROR(x) \ - do { \ - throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \ - " in function " + __func__ + ": " + x); \ +#define NVTE_UB_ERROR(x) \ + do { \ + throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \ + " in function " + __func__ + ": " + x); \ } while (false) -#define NCCLCHECK(cmd) \ - do { \ - ncclResult_t r = cmd; \ - if (r != ncclSuccess) { \ - printf("Failed, NCCL error %s:%d ''\n", __FILE__, __LINE__ /*,ncclGetErrorString(r)*/); \ - exit(EXIT_FAILURE); \ - } \ +#define NCCLCHECK(cmd) \ + do { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) { \ + printf("Failed, NCCL error %s:%d ''\n", __FILE__, __LINE__ /*,ncclGetErrorString(r)*/); \ + exit(EXIT_FAILURE); \ + } \ } while (0) -#define NCCLCHECKGOTO(call, RES, label) \ - do { \ - RES = call; \ - if (RES != ncclSuccess && RES != ncclInProgress) { \ - goto label; \ - } \ +#define NCCLCHECKGOTO(call, RES, label) \ + do { \ + RES = call; \ + if (RES != ncclSuccess && RES != ncclInProgress) { \ + goto label; \ + } \ } while (0); int pipe_rank(communicator *comm, int step) { @@ -89,12 +91,11 @@ int pipe_rank(communicator *comm, int step) { return newnode * numlocal + newlocal; } -int create_communicator_grouped2(communicator **comm, - int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, - std::function ext_alloc_copy_allgather, - std::function ext_barrier, - std::function ext_free, - int pipegpus, int pipenodes, int tensorgpus, int tensornodes) { +int create_communicator_grouped2( + communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, std::function ext_alloc_copy_allgather, + std::function ext_barrier, std::function ext_free, int pipegpus, + int pipenodes, int tensorgpus, int tensornodes) { *comm = reinterpret_cast(malloc(sizeof(communicator))); (*comm)->comm_world = EXT_COMM_WORLD; @@ -117,22 +118,20 @@ int create_communicator_grouped2(communicator **comm, (*comm)->push = 1; (*comm)->use_ce = 0; 
(*comm)->cga_size = 2; - for (int i = 0; i < userbuffers_op_types; i++) - (*comm)->basecounter[i] = 0; + for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0; (*comm)->head = 0; (*comm)->tail = 0; (*comm)->active_nreqs = 0; - for (int i = 0; i < userbuffers_op_types; i++) - (*comm)->active_req[i].active = -1; + for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1; - int device_clock = 0; + int device_clock = 0; // 110 sec wait time by default int sec_timeout = getenv("UB_TIMEOUT") ? atoi(getenv("UB_TIMEOUT")) : 110; CUDACHECK(cudaDeviceGetAttribute(&device_clock, cudaDevAttrClockRate, cur_dev)); (*comm)->ub_timeout = 1000ull * device_clock * sec_timeout; if ((*comm)->myrank == 0) { - printf("UB_TIMEOUT is set to %d sec, %" PRIu64 " cycles, freq: %dkhz\n", - sec_timeout, (*comm)->ub_timeout, device_clock); + printf("UB_TIMEOUT is set to %d sec, %" PRIu64 " cycles, freq: %dkhz\n", sec_timeout, + (*comm)->ub_timeout, device_clock); } (*comm)->comm_intra = EXT_COMM_INTRA; @@ -142,22 +141,14 @@ int create_communicator_grouped2(communicator **comm, cpu_set_t cpuset; CPU_ZERO(&cpuset); int core; - if (mylocal == 0) - core = 50; - if (mylocal == 1) - core = 58; - if (mylocal == 2) - core = 18; - if (mylocal == 3) - core = 26; - if (mylocal == 4) - core = 114; - if (mylocal == 5) - core = 122; - if (mylocal == 6) - core = 82; - if (mylocal == 7) - core = 90; + if (mylocal == 0) core = 50; + if (mylocal == 1) core = 58; + if (mylocal == 2) core = 18; + if (mylocal == 3) core = 26; + if (mylocal == 4) core = 114; + if (mylocal == 5) core = 122; + if (mylocal == 6) core = 82; + if (mylocal == 7) core = 90; CPU_SET(core, &cpuset); if (!getenv("NVTE_NODOUBLE")) { @@ -166,8 +157,7 @@ int create_communicator_grouped2(communicator **comm, else CPU_SET(core + 128, &cpuset); } - if (getenv("NVTE_DOPIN")) - pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); + if (getenv("NVTE_DOPIN")) pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); if (ndev == numlocal) { // all visible devices if (cur_dev != mylocal) @@ -248,8 +238,7 @@ int create_communicator_grouped2(communicator **comm, NCCLCHECKGOTO(ncclIpcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error); } } else { - for (int i = 0; i < (*comm)->ar2_nvrank; i++) - (*comm)->_barrier((*comm)->comm_intra); + for (int i = 0; i < (*comm)->ar2_nvrank; i++) (*comm)->_barrier((*comm)->comm_intra); NCCLCHECKGOTO(ncclIpcSocketRecvFd(&ipcSock, &fd), ret, error); for (int i = 0; i < (*comm)->ar2_nvsize - (*comm)->ar2_nvrank - 1; i++) (*comm)->_barrier((*comm)->comm_intra); @@ -273,11 +262,9 @@ int create_communicator_grouped2(communicator **comm, (*comm)->mc_baseptr = reinterpret_cast(mc_va); (*comm)->_barrier((*comm)->comm_world); - if (!(*comm)->myrank) - printf("MC initialized succesfully, window size = %ld\n", mc_maxsize); + if (!(*comm)->myrank) printf("MC initialized succesfully, window size = %ld\n", mc_maxsize); } else { - if (!(*comm)->myrank) - printf("MC NOT initialized and used\n"); + if (!(*comm)->myrank) printf("MC NOT initialized and used\n"); (*comm)->mc_maxsize = 0; (*comm)->mc_offset = 0; (*comm)->use_mc = 0; @@ -318,42 +305,38 @@ int create_communicator_grouped2(communicator **comm, pthread_attr_setschedparam(&attr, ¶m); if (getenv("NVTE_UBDEBUG")) - printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP " - "%dx%d PIPE_ID %d/%d\n", - myrank, numranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node, - (*comm)->ar_nvrank, (*comm)->my2_node, 
(*comm)->ar2_nvrank, (*comm)->num_nodes, - (*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id, - pipegpus * pipenodes); + printf( + "%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP " + "%dx%d PIPE_ID %d/%d\n", + myrank, numranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node, + (*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes, + (*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id, + pipegpus * pipenodes); fflush(NULL); return 0; } -int create_communicator_grouped(communicator **comm, - int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, - std::function ext_alloc_copy_allgather, - std::function ext_barrier, - std::function ext_free, - int pipegpus, int pipenodes) { - return create_communicator_grouped2( - comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, - ext_alloc_copy_allgather, ext_barrier, ext_free, - pipegpus, pipenodes, 1, 1); +int create_communicator_grouped( + communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, std::function ext_alloc_copy_allgather, + std::function ext_barrier, std::function ext_free, int pipegpus, + int pipenodes) { + return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + ext_alloc_copy_allgather, ext_barrier, ext_free, pipegpus, + pipenodes, 1, 1); } -int create_communicator(communicator **comm, - int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, - std::function ext_alloc_copy_allgather, - std::function ext_barrier, - std::function ext_free) { - return create_communicator_grouped2( - comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, - ext_alloc_copy_allgather, ext_barrier, ext_free, - 1, 1, 1, 1); +int create_communicator( + communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, std::function ext_alloc_copy_allgather, + std::function ext_barrier, std::function ext_free) { + return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + ext_alloc_copy_allgather, ext_barrier, ext_free, 1, 1, 1, 1); } -int create_communicator_grouped2_mpi(communicator **comm, - int pipegpus, int pipenodes, int tensorgpus, int tensornodes) { +int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes, + int tensorgpus, int tensornodes) { #ifdef UB_MPI_BOOTSTRAP // get global numbers int myrank, numranks; @@ -375,10 +358,8 @@ int create_communicator_grouped2_mpi(communicator **comm, color = 0; for (int n = 0; n < size; n++) { - if (n > 0 && strcmp(host_names[n - 1], host_names[n])) - color++; - if (strcmp(host_name, host_names[n]) == 0) - break; + if (n > 0 && strcmp(host_names[n - 1], host_names[n])) color++; + if (strcmp(host_name, host_names[n]) == 0) break; } free(host_names); @@ -401,10 +382,9 @@ int create_communicator_grouped2_mpi(communicator **comm, MPI_Comm_rank(EXT_COMM_INTER, &mynode); // finally call the abstracted constructor with MPI info - return create_communicator_grouped2(comm, - myrank, numranks, mylocal, numlocal, mynode, numnodes, - &ub_alloc_copy_allgather, &ub_barrier, &ub_free, - pipegpus, pipenodes, tensorgpus, tensornodes); + return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + &ub_alloc_copy_allgather, &ub_barrier, &ub_free, pipegpus, + pipenodes, tensorgpus, tensornodes); #else NVTE_UB_ERROR(std::string("Bootstrapping 
Userbuffers with MPI requires ") + std::string("building Transformer Engine with UB_MPI_BOOTSTRAP=1")); @@ -466,8 +446,7 @@ void destroy_communicator_mpi(communicator *comm) { } int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) { - if (comm->free_region > NVTE_MAX_REGIONS) - return -1; + if (comm->free_region > NVTE_MAX_REGIONS) return -1; int hndl = comm->free_region; comm->peer_ptr[hndl] = reinterpret_cast(malloc(sizeof(void *) * (comm->nvsize))); size_t aligned_size = bytes; @@ -559,8 +538,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * } CUCHECK(cuMemSetAccess(ptr, aligned_size * nranks, &accessDesc, 1)); - if (hndl == 0) - CUDACHECK(cudaMemset(comm->gpu_ptrs, 0, aligned_size)); + if (hndl == 0) CUDACHECK(cudaMemset(comm->gpu_ptrs, 0, aligned_size)); CUDACHECK( cudaMemcpy((reinterpret_cast(comm->gpu_ptrs)) + (hndl * nranks * sizeof(void *)), remptrs, nranks * sizeof(void *), cudaMemcpyHostToDevice)); @@ -585,13 +563,12 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * CUDACHECK(cudaIpcGetMemHandle(&memhndl, *gpubuff)); cudaIpcMemHandle_t *tmp; - comm->_alloc_copy_allgather( - reinterpret_cast(&tmp), reinterpret_cast(&memhndl), - sizeof(cudaIpcMemHandle_t), comm->comm_intra); + comm->_alloc_copy_allgather(reinterpret_cast(&tmp), reinterpret_cast(&memhndl), + sizeof(cudaIpcMemHandle_t), comm->comm_intra); for (int i = 0; i < comm->nvsize; i++) { if (i != comm->nvrank) { - CUDACHECK(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*) + CUDACHECK(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*) cudaIpcMemLazyEnablePeerAccess)); } } diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index f98688b3eec46087d2ed7ceb842fb1f46880d3de..7632e69a0aef249d9a5b2a6d2b80e7e03999ad4c 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -14,39 +14,38 @@ #include #endif -#include "userbuffers.h" - -#include -#include #include #include +#include +#include + +#include "userbuffers.h" #define MAX_THREADS 1024 -#define CUDACHECK(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ } while (0) -#define ATOMIC_CONSUMER(chunk) \ - if (counters) { \ - if (threadIdx.x == 0 && blockIdx.x == 0) { \ - while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \ - } \ - ((unsigned int *)counters)[chunk] = 1; \ - asm volatile("fence.sc.gpu;\n"); \ - } \ - if (blockIdx.x == 0) \ - __syncthreads(); \ +#define ATOMIC_CONSUMER(chunk) \ + if (counters) { \ + if (threadIdx.x == 0 && blockIdx.x == 0) { \ + while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \ + } \ + ((unsigned int *)counters)[chunk] = 1; \ + asm volatile("fence.sc.gpu;\n"); \ + } \ + if (blockIdx.x == 0) __syncthreads(); \ } -#define ATOMIC_PRODUCER(chunk) \ - if (counters) { \ - ((unsigned int *)counters)[chunk] = 0; \ +#define ATOMIC_PRODUCER(chunk) \ + if (counters) { \ + ((unsigned int *)counters)[chunk] = 0; \ } // Return true if producer > consumer, otherwise false 
while preventing integer overflow @@ -54,21 +53,21 @@ #define CHECK_IDS(producer, consumer) (((unsigned)(producer) - (unsigned)(consumer)) & (~INT_MAX)) // Strip the path from a full filename -#define FILENAME(file) ({ \ - const char* filename = file; \ - const char* basename = filename; \ - for (const char* ptr = filename; *ptr != '\0'; ptr++) { \ - if (*ptr == '/' || *ptr == '\\') { \ - basename = ptr + 1; \ - } \ - } \ - basename; \ -}) +#define FILENAME(file) \ + ({ \ + const char *filename = file; \ + const char *basename = filename; \ + for (const char *ptr = filename; *ptr != '\0'; ptr++) { \ + if (*ptr == '/' || *ptr == '\\') { \ + basename = ptr + 1; \ + } \ + } \ + basename; \ + }) // Printf to provide enough information so it is easier to attribute failures -#define UB_PRINT(message, ...) printf("[%s:%s:%d] " message "\n", FILENAME(__FILE__), \ - __FUNCTION__, \ - __LINE__, __VA_ARGS__) +#define UB_PRINT(message, ...) \ + printf("[%s:%s:%d] " message "\n", FILENAME(__FILE__), __FUNCTION__, __LINE__, __VA_ARGS__) // Report and error on timeout #define CHECK_TIMEOUT(t, timeout) ((clock64() - (t)) > timeout) @@ -111,8 +110,7 @@ __global__ void __launch_bounds__(MAX_THREADS) int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) - dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); __syncthreads(); for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines; @@ -132,8 +130,7 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int i = 1; i < RANKS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 8; j++) - s[j] += x[j]; + for (int j = 0; j < 8; j++) s[j] += x[j]; } #pragma unroll for (int i = 0; i < RANKS; i++) { @@ -143,8 +140,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) - __threadfence_system(); + if (threadIdx.x == 0) __threadfence_system(); __syncthreads(); if (threadIdx.x < RANKS) { @@ -154,13 +150,12 @@ __global__ void __launch_bounds__(MAX_THREADS) while (CHECK_IDS(*flag, reduce_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { UB_PRINT("[%d] Allreduce Gather: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, - threadIdx.x, reduce_id, *flag); + threadIdx.x, reduce_id, *flag); break; } } } - if (threadIdx.x == 0 && blockIdx.x == 0) - *reduceidptr = reduce_id; + if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; } // fp16 inplace reduce kernel (Volta,Hopper) template @@ -188,8 +183,8 @@ __global__ void __launch_bounds__(MAX_THREADS) clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { - UB_PRINT("[%d ]Allreduce reduce-scatter:SM %d [%d]: expecting %d got %d", - myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); + UB_PRINT("[%d ]Allreduce reduce-scatter:SM %d [%d]: expecting %d got %d", myrank, + blockIdx.x, threadIdx.x, reduce_id, *flag); break; } } @@ -200,8 +195,7 @@ __global__ void __launch_bounds__(MAX_THREADS) int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) - dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); __syncthreads(); for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines; @@ -220,15 +214,13 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int i = 1; i < RANKS; i++) { half *x = 
reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 8; j++) - s[j] += x[j]; + for (int j = 0; j < 8; j++) s[j] += x[j]; } userptr[myrank][lineoffset + line] = sum; } __syncthreads(); - if (threadIdx.x == 0) - __threadfence(); + if (threadIdx.x == 0) __threadfence(); __syncthreads(); if (threadIdx.x < RANKS) { @@ -238,7 +230,7 @@ __global__ void __launch_bounds__(MAX_THREADS) while (CHECK_IDS(*flag, reduce_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { UB_PRINT("[%d] Allreduce gather: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, - threadIdx.x, reduce_id, *flag); + threadIdx.x, reduce_id, *flag); break; } } @@ -270,8 +262,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userptr[myrank][lineoffset + line + blockDim.x * dest[i]] = val[i]; } } - if (threadIdx.x == 0 && blockIdx.x == 0) - *reduceidptr = reduce_id; + if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; } // fp16 inplace reduce kernel (Ampere) template @@ -293,15 +284,14 @@ __global__ void __launch_bounds__(MAX_THREADS) reduceidptr = myptr - NVTE_MAX_OPS; // +op; reduce_id = (*reduceidptr) + 1; flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; - if (blockIdx.x == 0) - flagptr[physgpu] = reduce_id; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, - threadIdx.x, reduce_id, *flag); + threadIdx.x, reduce_id, *flag); break; } } @@ -310,15 +300,13 @@ __global__ void __launch_bounds__(MAX_THREADS) if (threadIdx.x == 0) { const int adder = blockIdx.x == 0 ? 
NVTE_MAX_SMS - gridDim.x + 1 : 1; int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); - if (old_val + adder == NVTE_MAX_SMS * reduce_id) - lastSM = 1; + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; } int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) - dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); __syncthreads(); for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; @@ -337,26 +325,20 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int i = 1; i < RANKS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 8; j++) - s[j] += x[j]; + for (int j = 0; j < 8; j++) s[j] += x[j]; } userptr[myrank][mylineoffset + line] = sum; } - if (threadIdx.x == 0 && lastSM) - *reduceidptr = reduce_id; + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; } // fp16 inplace reduce-scatter kernel template -__global__ void __launch_bounds__(MAX_THREADS) - userbuffers_fp16_sum_inplace_gpu_rr_rs_oop(const int op, const int flagoffset, - const int firstrank, const int myrank, - const int gpustep, const int mylineoffset, - const int totallines, const int rowlines, - const int skiplines, void **commbuff, - const int handleridx, void *outbuf, - const uint64_t ub_timeout) { +__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_rs_oop( + const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, + const int mylineoffset, const int totallines, const int rowlines, const int skiplines, + void **commbuff, const int handleridx, void *outbuf, const uint64_t ub_timeout) { __shared__ int4 *userptr[RANKS]; volatile int *flagptr; int physgpu, targetgpu, *myptr; @@ -369,15 +351,14 @@ __global__ void __launch_bounds__(MAX_THREADS) reduceidptr = myptr - NVTE_MAX_OPS; // +op; reduce_id = (*reduceidptr) + 1; flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; - if (blockIdx.x == 0) - flagptr[physgpu] = reduce_id; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, - threadIdx.x, reduce_id, *flag); + threadIdx.x, reduce_id, *flag); break; } } @@ -386,15 +367,13 @@ __global__ void __launch_bounds__(MAX_THREADS) if (threadIdx.x == 0) { const int adder = blockIdx.x == 0 ? 
NVTE_MAX_SMS - gridDim.x + 1 : 1; int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); - if (old_val + adder == NVTE_MAX_SMS * reduce_id) - lastSM = 1; + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; } int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) - dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); __syncthreads(); for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; @@ -413,15 +392,13 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int i = 1; i < RANKS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 8; j++) - s[j] += x[j]; + for (int j = 0; j < 8; j++) s[j] += x[j]; } (reinterpret_cast(outbuf))[(line / rowlines) * skiplines + (line % rowlines)] = sum; } - if (threadIdx.x == 0 && lastSM) - *reduceidptr = reduce_id; + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; } // fp16 reduce-scatter kernel (out of place) #if __CUDA_ARCH__ >= 900 @@ -451,7 +428,7 @@ __global__ void __launch_bounds__(MAX_THREADS) while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { UB_PRINT("Reduce-scatter: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, - reduce_id, *flag); + reduce_id, *flag); break; } } @@ -508,8 +485,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) - __threadfence_system(); + if (threadIdx.x == 0) __threadfence_system(); __syncthreads(); if (threadIdx.x < RANKS) { @@ -519,13 +495,12 @@ __global__ void __launch_bounds__(MAX_THREADS) while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > 2ull * TIMEOUT) { UB_PRINT("Allgather: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, reduce_id, - *flag); + *flag); break; } } } - if (threadIdx.x == 0 && blockIdx.x == 0) - *reduceidptr = reduce_id; + if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; } // fp16 inplace reduce kernel (Hopper) MC template @@ -548,15 +523,14 @@ __global__ void __launch_bounds__(MAX_THREADS) reduceidptr = myptr - NVTE_MAX_OPS; // +op; reduce_id = (*reduceidptr) + 1; flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; - if (blockIdx.x == 0) - flagptr[physgpu] = reduce_id; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (CHECK_TIMEOUT(s, ub_timeout)) { - UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); - break; + break; } } } @@ -564,8 +538,7 @@ __global__ void __launch_bounds__(MAX_THREADS) if (threadIdx.x == 0) { const int adder = blockIdx.x == 0 ? 
NVTE_MAX_SMS - gridDim.x + 1 : 1; int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); - if (old_val + adder == NVTE_MAX_SMS * reduce_id) - lastSM = 1; + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; } const int loop_step0 = blockDim.x * gridDim.x; const int loop_step = loop_step0 * UNROLL_MC; @@ -590,8 +563,7 @@ __global__ void __launch_bounds__(MAX_THREADS) : "memory"); #endif #pragma unroll - for (int i = 0; i < UNROLL_MC; i++) - localptr[mylineoffset + line + i * loop_step0] = val[i]; + for (int i = 0; i < UNROLL_MC; i++) localptr[mylineoffset + line + i * loop_step0] = val[i]; } for (int line = end_aligned; line < end_elem; line += loop_step0) { uint4 val; @@ -609,8 +581,7 @@ __global__ void __launch_bounds__(MAX_THREADS) localptr[mylineoffset + line] = val; } - if (threadIdx.x == 0 && lastSM) - *reduceidptr = reduce_id; + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; } // fp16 inplace reduce-scatter kernel MC template @@ -634,14 +605,13 @@ __global__ void __launch_bounds__(MAX_THREADS) reduceidptr = myptr - NVTE_MAX_OPS; // +op; reduce_id = (*reduceidptr) + 1; flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; - if (blockIdx.x == 0) - flagptr[physgpu] = reduce_id; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, - threadIdx.x, reduce_id, *flag); + threadIdx.x, reduce_id, *flag); break; } } @@ -650,8 +620,7 @@ __global__ void __launch_bounds__(MAX_THREADS) if (threadIdx.x == 0) { const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); - if (old_val + adder == NVTE_MAX_SMS * reduce_id) - lastSM = 1; + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; } const int loop_step0 = blockDim.x * gridDim.x; @@ -693,11 +662,10 @@ __global__ void __launch_bounds__(MAX_THREADS) : "l"(mc_ptr + (mylineoffset + line)) : "memory"); #endif - reinterpret_cast (outbuf)[(line / rowlines) * skiplines + (line % rowlines)] = val; + reinterpret_cast(outbuf)[(line / rowlines) * skiplines + (line % rowlines)] = val; } - if (threadIdx.x == 0 && lastSM) - *reduceidptr = reduce_id; + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; } // fp16 reduce-scatter kernel (out of place) fp16 MC template @@ -731,8 +699,7 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int line = start_elem; line < end_aligned; line += loop_step) { uint4 val[UNROLL_MC]; #pragma unroll - for (int i = 0; i < UNROLL_MC; i++) - val[i] = localptr[mylineoffset + line + i * loop_step0]; + for (int i = 0; i < UNROLL_MC; i++) val[i] = localptr[mylineoffset + line + i * loop_step0]; #pragma unroll for (int i = 0; i < UNROLL_MC; i++) asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"( @@ -749,8 +716,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) - __threadfence_system(); + if (threadIdx.x == 0) __threadfence_system(); __syncthreads(); __shared__ int lastSM; @@ -764,16 +730,15 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); if (lastSM && threadIdx.x < RANKS) { - if (threadIdx.x == 0) - *reduceidptr = reduce_id; + if (threadIdx.x == 0) *reduceidptr = reduce_id; flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); while 
(CHECK_IDS(*flag, reduce_id)) { - if (CHECK_TIMEOUT(s, ub_timeout)) { - UB_PRINT("[%d] Allgather: SM %d [%d]: expecting %d got %d", - myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); - break; + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Allgather: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, threadIdx.x, + reduce_id, *flag); + break; } } } @@ -787,11 +752,14 @@ __global__ void __launch_bounds__(MAX_THREADS) const int numlines, void **commbuff, const int handleridx, float4 *mc_ptr) {} template -__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rs_oop( - const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, - const int mylineoffset, const int totallines, const int rowlines, const int skiplines, - void **commbuff, const int handleridx, void *outbuf, float4 *mc_ptr, - const uint64_t ub_timeout) {} +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc_rs_oop(const int op, const int flagoffset, + const int firstrank, const int myrank, + const int gpustep, const int mylineoffset, + const int totallines, const int rowlines, + const int skiplines, void **commbuff, + const int handleridx, void *outbuf, float4 *mc_ptr, + const uint64_t ub_timeout) {} template __global__ void __launch_bounds__(MAX_THREADS) @@ -814,8 +782,7 @@ template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8( const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int mylineoffset, const int totallines, const int rowlines, const int skiplines, - void **commbuff, const int handleridx, void *outbuf, float *scale, - const uint64_t ub_timeout) { + void **commbuff, const int handleridx, void *outbuf, float *scale, const uint64_t ub_timeout) { __shared__ int4 *userptr[RANKS]; volatile int *flagptr; int physgpu, targetgpu, *myptr; @@ -830,15 +797,14 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ reduceidptr = myptr - NVTE_MAX_OPS; // +op; reduce_id = (*reduceidptr) + 1; flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; - if (blockIdx.x == 0) - flagptr[physgpu] = reduce_id; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, - threadIdx.x, reduce_id, *flag); + threadIdx.x, reduce_id, *flag); break; } } @@ -847,14 +813,12 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ if (threadIdx.x == 0) { const int adder = blockIdx.x == 0 ? 
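// ---------------------------------------------------------------------------
// Minimal sketch (illustration only, not part of this patch) of the
// synchronization idiom shared by the reduce-scatter/allgather kernels above:
// block 0 publishes the new reduce_id to every peer, every block spins on a
// volatile flag with a timeout, and the block whose atomicAdd brings a shared
// counter up to NVTE_MAX_SMS * reduce_id is elected "last SM" and writes the
// id back for the next call. The identifiers counter, published_step, step_id
// and max_blocks below are illustrative, not names from this file.
__global__ void last_block_election_sketch(int *counter, int *published_step,
                                           const int step_id, const int max_blocks) {
  __shared__ int last_block;
  if (threadIdx.x == 0) {
    // Block 0 also adds the ticks of the blocks that were not launched, so the
    // counter advances by exactly max_blocks per step regardless of gridDim.x.
    const int adder = blockIdx.x == 0 ? max_blocks - gridDim.x + 1 : 1;
    const int old_val = atomicAdd(counter, adder);
    last_block = (old_val + adder == max_blocks * step_id) ? 1 : 0;
  }
  __syncthreads();
  // ... per-block reduction work goes here ...
  if (threadIdx.x == 0 && last_block) *published_step = step_id;
}
// ---------------------------------------------------------------------------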
NVTE_MAX_SMS - gridDim.x + 1 : 1; int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); - if (old_val + adder == NVTE_MAX_SMS * reduce_id) - lastSM = 1; + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; } int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) - dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); __syncthreads(); for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; @@ -873,8 +837,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ for (int i = 0; i < RANKS; i++) { fp8type *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) - s[j] += hscale * (half)(x[j]); + for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]); } int hline = 2 * line; (reinterpret_cast(outbuf))[(hline / rowlines) * skiplines + (hline % rowlines)] = @@ -884,8 +847,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ sum[1]; } - if (threadIdx.x == 0 && lastSM) - *reduceidptr = reduce_id; + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; } // fp16 reduce-scatter kernel (out of place) (fp8->fp16) template @@ -919,15 +881,14 @@ __global__ void __launch_bounds__(MAX_THREADS) lastSM = 0; if (threadIdx.x < RANKS) { reduce_id++; - if (blockIdx.x == 0) - flagptr[physgpu] = reduce_id; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, - threadIdx.x, reduce_id, *flag); + threadIdx.x, reduce_id, *flag); break; } } @@ -936,15 +897,13 @@ __global__ void __launch_bounds__(MAX_THREADS) if (threadIdx.x == 0) { const int adder = blockIdx.x == 0 ? 
NVTE_MAX_SMS - gridDim.x + 1 : 1; int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), /*numchunks * */ adder); - if (old_val + adder == NVTE_MAX_SMS * (reduce_id /* + numchunks*/)) - lastSM = 1; + if (old_val + adder == NVTE_MAX_SMS * (reduce_id /* + numchunks*/)) lastSM = 1; } int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) - dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); __syncthreads(); for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; @@ -974,26 +933,20 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int i = 0; i < RANKS; i++) { fp8type *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) - s[j] += hscale * (half)(x[j]); + for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]); } (reinterpret_cast(outbuf))[index1_out] = sum[0]; (reinterpret_cast(outbuf))[index2_out] = sum[1]; } } - if (threadIdx.x == 0 && lastSM) - *reduceidptr = reduce_id; + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; } // fp16 reduce-scatter kernel (out of place) (fp8->fp16) template -__global__ void __launch_bounds__(MAX_THREADS) - userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride(const int op, const int flagoffset, - const int firstrank, const int myrank, - const int gpustep, const int mylineoffset, - const int totallines, const int rowlines, - const int skiplines, void **commbuff, - const int handleridx, void *outbuf, - const uint64_t ub_timeout) { +__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride( + const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, + const int mylineoffset, const int totallines, const int rowlines, const int skiplines, + void **commbuff, const int handleridx, void *outbuf, const uint64_t ub_timeout) { __shared__ int4 *userptr[RANKS]; volatile int *flagptr; int physgpu, targetgpu, *myptr; @@ -1007,15 +960,14 @@ __global__ void __launch_bounds__(MAX_THREADS) reduceidptr = myptr - NVTE_MAX_OPS; // +op; reduce_id = (*reduceidptr) + 1; flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; - if (blockIdx.x == 0) - flagptr[physgpu] = reduce_id; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, - threadIdx.x, reduce_id, *flag); + threadIdx.x, reduce_id, *flag); break; } } @@ -1024,15 +976,13 @@ __global__ void __launch_bounds__(MAX_THREADS) if (threadIdx.x == 0) { const int adder = blockIdx.x == 0 ? 
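// ---------------------------------------------------------------------------
// Minimal sketch (illustration only) of the fp8 -> fp16 accumulation performed
// by the *_rs_oop_fp8 / *_atomic_fp8 kernels above: each 16-byte int4 packs
// sizeof(int4) / sizeof(fp8type) = 16 fp8 values, and every value is
// dequantized with a single scale and reduced in half precision. fp8type
// stands for __nv_fp8_e4m3 or __nv_fp8_e5m2; the function and variable names
// below are illustrative.
template <typename fp8type>
__device__ void fp8_dequant_accumulate_sketch(const int4 &packed, const half hscale,
                                              half *acc /* 16 running sums */) {
  const fp8type *x = reinterpret_cast<const fp8type *>(&packed);
#pragma unroll
  for (int j = 0; j < static_cast<int>(sizeof(int4) / sizeof(fp8type)); j++)
    acc[j] += hscale * (half)(x[j]);  // dequantize, then reduce in fp16
}
// ---------------------------------------------------------------------------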
NVTE_MAX_SMS - gridDim.x + 1 : 1; int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); - if (old_val + adder == NVTE_MAX_SMS * reduce_id) - lastSM = 1; + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; } int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) - dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; line += blockDim.x * gridDim.x) { @@ -1052,16 +1002,14 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int i = 1; i < RANKS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 8; j++) - s[j] += x[j]; + for (int j = 0; j < 8; j++) s[j] += x[j]; } int index_out = (line / rowlines) * skiplines + (line % rowlines); (reinterpret_cast(outbuf))[index_out] = sum; } - if (threadIdx.x == 0 && lastSM) - *reduceidptr = reduce_id; + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; } // fp16 reduce-scatter kernel (out of place) fp16 template @@ -1102,15 +1050,14 @@ __global__ void __launch_bounds__(MAX_THREADS) reduceidptr = myptr - NVTE_MAX_OPS; // +op; reduce_id = (*reduceidptr) + 1; flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; - if (blockIdx.x == 0) - flagptr[physgpu] = reduce_id; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, - threadIdx.x, reduce_id, *flag); + threadIdx.x, reduce_id, *flag); break; } } @@ -1119,15 +1066,13 @@ __global__ void __launch_bounds__(MAX_THREADS) if (threadIdx.x == 0) { const int adder = blockIdx.x == 0 ? 
NVTE_MAX_SMS - gridDim.x + 1 : 1; int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); - if (old_val + adder == NVTE_MAX_SMS * reduce_id) - lastSM = 1; + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; } int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) - dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; line += blockDim.x * gridDim.x) { @@ -1147,16 +1092,14 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int i = 1; i < RANKS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 8; j++) - s[j] += x[j]; + for (int j = 0; j < 8; j++) s[j] += x[j]; } int index_out = (line / rowlines) * skiplines + (line % rowlines); (reinterpret_cast(outbuf))[index_out] = sum; } - if (threadIdx.x == 0 && lastSM) - *reduceidptr = reduce_id; + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; } // fp16 reduce-scatter kernel (out of place) fp16 template @@ -1198,15 +1141,14 @@ __global__ void __launch_bounds__(MAX_THREADS) reduceidptr = myptr - NVTE_MAX_OPS; // +op; reduce_id = (*reduceidptr) + 1; flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; - if (blockIdx.x == 0) - flagptr[physgpu] = reduce_id; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, - threadIdx.x, reduce_id, *flag); + threadIdx.x, reduce_id, *flag); break; } } @@ -1215,15 +1157,13 @@ __global__ void __launch_bounds__(MAX_THREADS) if (threadIdx.x == 0) { const int adder = blockIdx.x == 0 ? 
NVTE_MAX_SMS - gridDim.x + 1 : 1; int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); - if (old_val + adder == NVTE_MAX_SMS * reduce_id) - lastSM = 1; + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; } int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) - dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; line += blockDim.x * gridDim.x) { @@ -1243,15 +1183,13 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int i = 1; i < RANKS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 8; j++) - s[j] += x[j]; + for (int j = 0; j < 8; j++) s[j] += x[j]; } int index_out = chunk_i * mylineoffset + (line / rowlines) * skiplines + (line % rowlines); (reinterpret_cast(outbuf))[index_out] = sum; } - if (threadIdx.x == 0 && lastSM) - *reduceidptr = reduce_id; + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; } } // fp16 reduce-scatter kernel (out of place) fp16 @@ -1317,15 +1255,14 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); if (lastSM && threadIdx.x < RANKS) { - if (threadIdx.x == 0) - *reduceidptr = reduce_id; + if (threadIdx.x == 0) *reduceidptr = reduce_id; flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { UB_PRINT("[%d] Allgather: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, threadIdx.x, - reduce_id, *flag); + reduce_id, *flag); break; } } @@ -1380,8 +1317,7 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int line = start_elem; line < end_aligned; line += loop_step) { int4 val[UNROLLAG]; #pragma unroll - for (int j = 0; j < UNROLLAG; j++) - val[j] = localptr[mylineoffset + line + loop_step0 * j]; + for (int j = 0; j < UNROLLAG; j++) val[j] = localptr[mylineoffset + line + loop_step0 * j]; #pragma unroll for (int j = 0; j < UNROLLAG; j++) @@ -1400,8 +1336,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) - __threadfence_system(); + if (threadIdx.x == 0) __threadfence_system(); __syncthreads(); __shared__ int lastSM; @@ -1415,8 +1350,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); if (lastSM && threadIdx.x < RANKS) { - if (threadIdx.x == 0) - *reduceidptr = reduce_id; + if (threadIdx.x == 0) *reduceidptr = reduce_id; flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); @@ -1430,15 +1364,15 @@ __global__ void __launch_bounds__(MAX_THREADS) } } // fp16 inplace allgather kernel (Volta,Hopper) -#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ - cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ - cudaLaunchAttribute attribute_ub[2]; \ - attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \ - attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? 
comm->cga_size : 1; \ - attribute_ub[1].val.clusterDim.y = 1; \ - attribute_ub[1].val.clusterDim.z = 1; \ - attribute_ub[0].id = cudaLaunchAttributeCooperative; \ - cfg.attrs = attribute_ub; \ +#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[2]; \ + attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \ + attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \ + attribute_ub[1].val.clusterDim.y = 1; \ + attribute_ub[1].val.clusterDim.z = 1; \ + attribute_ub[0].id = cudaLaunchAttributeCooperative; \ + cfg.attrs = attribute_ub; \ cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1; #define callranks_ag(x) \ @@ -1464,226 +1398,226 @@ __global__ void __launch_bounds__(MAX_THREADS) kernelArgs)); \ } -#define callranks_agMC(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ - arg6 = offset / 8 + arg4 * arg7; \ - void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ - int arg9 = handler * comm->nvsize; \ - uint4 *arg10 = reinterpret_cast(comm->mc_ptr[handler]); \ - uint64_t arg11 = comm->ub_timeout; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ - reinterpret_cast(&arg11)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_ag), kernelArgs)); \ - } - -#define callranks_rs(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ - arg6 = offset / 8 + arg4 * arg7; \ - void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ - int arg9 = handler * comm->nvsize; \ - uint64_t arg10 = comm->ub_timeout; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs), kernelArgs)); \ - } - -#define callranks_rsMC(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 
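// ---------------------------------------------------------------------------
// Minimal sketch (illustration only) of the extended launch path that
// SETUP_LAUNCH_CONFIG and the callranks_* macros around this point implement:
// fill a cudaLaunchConfig_t, attach a cooperative-launch attribute plus (on
// Hopper) a thread-block cluster dimension, gather the kernel arguments as an
// array of pointers, and launch through cudaLaunchKernelExC. The kernel
// sketch_kernel, its arguments, and the explicit cooperative value are
// placeholders/assumptions, not taken from this file.
__global__ void sketch_kernel(int n, float *data) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= 2.0f;
}

void launch_with_attributes_sketch(int sms, int threads, int cluster_size, bool is_hopper,
                                   int n, float *data, cudaStream_t stream) {
  cudaLaunchConfig_t cfg = {};
  cfg.gridDim = dim3(sms);
  cfg.blockDim = dim3(threads);
  cfg.dynamicSmemBytes = 0;
  cfg.stream = stream;

  cudaLaunchAttribute attrs[2];
  attrs[0].id = cudaLaunchAttributeCooperative;
  attrs[0].val.cooperative = 1;
  attrs[1].id = cudaLaunchAttributeClusterDimension;
  attrs[1].val.clusterDim.x = cluster_size;  // e.g. comm->cga_size when it divides sms
  attrs[1].val.clusterDim.y = 1;
  attrs[1].val.clusterDim.z = 1;
  cfg.attrs = attrs;
  cfg.numAttrs = is_hopper ? 2 : 1;  // cluster attribute only used on sm_90+

  void *args[] = {reinterpret_cast<void *>(&n), reinterpret_cast<void *>(&data)};
  cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(sketch_kernel), args);
}
// ---------------------------------------------------------------------------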
2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ - arg6 = offset / 8 + arg4 * arg7; \ - void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ - int arg9 = handler * comm->nvsize; \ - void *arg10 = comm->mc_ptr[handler]; \ - uint64_t arg11 = comm->ub_timeout; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ - reinterpret_cast(&arg11)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_rs), kernelArgs)); \ - } - -#define callranks_rs_oop(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ - arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ - void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ - int arg11 = handler * comm->nvsize; \ - void *arg12 = output; \ - uint64_t arg13 = comm->ub_timeout; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ - reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ - reinterpret_cast(&arg13)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop), \ - kernelArgs)); \ - } - -#define callranks_rs_oop_fp8(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \ - arg6 = offset / 16 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ - void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ - int arg11 = handler * comm->nvsize; \ - void *arg12 = output; \ - float *arg13 = scale; \ - uint64_t arg14 = comm->ub_timeout; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ - reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ - reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, \ - reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8), \ - kernelArgs)); \ - } - -#define callranks_rs_oopMC(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ - arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ - void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ - int arg11 = handler * comm->nvsize; \ - void *arg12 = output; \ - void *arg13 = comm->mc_ptr[handler]; \ - uint64_t arg14 = comm->ub_timeout; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ - reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ - reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_rs_oop), \ - kernelArgs)); \ - } - -#define callranks_rs_oop_atomic_fp8(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \ - arg6 = offset / 16, arg8 = rowelements / 8, arg9 = strideelements_out / 8, \ - arg10 = strideelements_in / 16; \ - void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ - int arg12 = handler * comm->nvsize; \ - void *arg13 = output; \ - float *arg14 = scale; \ - void *arg15 = counters; \ - int arg16 = numchunks, arg17 = atomicindex; \ - uint64_t arg18 = comm->ub_timeout; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ - reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ - reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ - reinterpret_cast(&arg15), reinterpret_cast(&arg16), \ - reinterpret_cast(&arg17), reinterpret_cast(&arg18)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, \ - reinterpret_cast( \ - userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8), \ - kernelArgs)); \ - } - -#define callranks_rs_oop_stride(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ - arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8; \ - void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ - int arg11 = handler * comm->nvsize; \ - void *arg12 = output; \ - uint64_t arg13 = comm->ub_timeout; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ - reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ - reinterpret_cast(&arg13)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride), \ - kernelArgs)); \ - } - -#define callranks_rs_oop_stride_atomic(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ - arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \ - void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ - int arg12 = handler * comm->nvsize; \ - void *arg13 = output; \ - void *arg14 = counters; \ - uint64_t arg15 = comm->ub_timeout; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ - reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ - reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ - reinterpret_cast(&arg15)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, \ - reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic), \ - kernelArgs)); \ +#define callranks_agMC(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + uint4 *arg10 = reinterpret_cast(comm->mc_ptr[handler]); \ + uint64_t arg11 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_ag), kernelArgs)); \ + } + +#define callranks_rs(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + uint64_t arg10 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs), kernelArgs)); \ + } + +#define callranks_rsMC(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + void *arg10 = comm->mc_ptr[handler]; \ + uint64_t arg11 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_rs), kernelArgs)); \ + } + +#define callranks_rs_oop(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + uint64_t arg13 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop), \ + kernelArgs)); \ + } + +#define callranks_rs_oop_fp8(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \ + arg6 = offset / 16 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + float *arg13 = scale; \ + uint64_t arg14 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8), \ + kernelArgs)); \ + } + +#define callranks_rs_oopMC(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + void *arg13 = comm->mc_ptr[handler]; \ + uint64_t arg14 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_rs_oop), \ + kernelArgs)); \ + } + +#define callranks_rs_oop_atomic_fp8(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \ + arg6 = offset / 16, arg8 = rowelements / 8, arg9 = strideelements_out / 8, \ + arg10 = strideelements_in / 16; \ + void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ + int arg12 = handler * comm->nvsize; \ + void *arg13 = output; \ + float *arg14 = scale; \ + void *arg15 = counters; \ + int arg16 = numchunks, arg17 = atomicindex; \ + uint64_t arg18 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ + reinterpret_cast(&arg15), reinterpret_cast(&arg16), \ + reinterpret_cast(&arg17), reinterpret_cast(&arg18)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast( \ + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8), \ + kernelArgs)); \ + } + +#define callranks_rs_oop_stride(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + uint64_t arg13 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride), \ + kernelArgs)); \ + } + +#define callranks_rs_oop_stride_atomic(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \ + void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ + int arg12 = handler * comm->nvsize; \ + void *arg13 = output; \ + void *arg14 = counters; \ + uint64_t arg15 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ + reinterpret_cast(&arg15)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic), \ + kernelArgs)); \ } #define callranks_rs_oop_stride_multiatomic(x) \ @@ -1726,12 +1660,10 @@ void reducescatter2_userbuff_strided(void *output, const int handler, const int const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - if (elements < 64) - return; + if (elements < 64) return; int sms = ar_nvsize == 1 ? 2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) - warps = ar_nvsize; + if (warps < ar_nvsize) warps = ar_nvsize; SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8) @@ -1749,12 +1681,10 @@ void reducescatter2_userbuff_strided_atomic(void *output, const int handler, con const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - if (elements < 64) - return; + if (elements < 64) return; int sms = ar_nvsize == 1 ? 2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) - warps = ar_nvsize; + if (warps < ar_nvsize) warps = ar_nvsize; SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4) @@ -1777,12 +1707,10 @@ void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, c const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; assert(comm->sm_arch >= 9); - if (elements < 128) - return; + if (elements < 128) return; int sms = ar_nvsize == 1 ? 2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) - warps = ar_nvsize; + if (warps < ar_nvsize) warps = ar_nvsize; SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8) @@ -1823,12 +1751,10 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - if (elements < 64) - return; + if (elements < 64) return; int sms = ar_nvsize == 1 ? 
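// ---------------------------------------------------------------------------
// Minimal sketch (illustration only) of the per-rank partitioning that the
// callranks_* macros above encode in arg6/arg7: the buffer is viewed as
// 16-byte "lines" (8 halves, or 16 fp8 values), the lines are split evenly
// across the ranks, and in the non-strided variants each rank's slice starts
// right after the slices of the lower ranks (the *_stride variants pass
// offset / 8 without the rank term). Struct and function names below are
// illustrative.
struct RankSliceSketch {
  int lines_per_rank;  // arg7: elements / 8 / nranks   (elements / 16 / nranks for fp8)
  int first_line;      // arg6: offset / 8 + rank * lines_per_rank
};

inline RankSliceSketch rank_slice_sketch(int elements, int offset, int rank, int nranks) {
  RankSliceSketch s;
  s.lines_per_rank = elements / 8 / nranks;
  s.first_line = offset / 8 + rank * s.lines_per_rank;
  return s;
}
// ---------------------------------------------------------------------------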
2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) - warps = ar_nvsize; + if (warps < ar_nvsize) warps = ar_nvsize; SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4) @@ -1844,12 +1770,10 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - if (elements < 64) - return; + if (elements < 64) return; int sms = ar_nvsize == 1 ? 2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) - warps = ar_nvsize; + if (warps < ar_nvsize) warps = ar_nvsize; SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { @@ -1883,12 +1807,10 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - if (elements < 64) - return; + if (elements < 64) return; int sms = ar_nvsize == 1 ? 2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) - warps = ar_nvsize; + if (warps < ar_nvsize) warps = ar_nvsize; SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { @@ -1909,12 +1831,10 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - if (elements < 64) - return; + if (elements < 64) return; int sms = ar_nvsize == 1 ? 2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) - warps = ar_nvsize; + if (warps < ar_nvsize) warps = ar_nvsize; SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { @@ -1941,12 +1861,10 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; assert(comm->sm_arch >= 9); - if (elements < 128) - return; + if (elements < 128) return; int sms = ar_nvsize == 1 ? 
2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) - warps = ar_nvsize; + if (warps < ar_nvsize) warps = ar_nvsize; SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) @@ -1982,16 +1900,13 @@ __global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *f atomicAdd_system(flagptr, 1); } -__global__ void kuserbuffers_inc(int *id) { - atomicAdd(id, 1); -} +__global__ void kuserbuffers_inc(int *id) { atomicAdd(id, 1); } __global__ void kuserbuffers_dummy(void) {} __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pullrecv(int myrank, int peer, int nvrank, int nvpeer, int *recv_id, int *flagptr, - int4 *srcptr, int4 *dstptr, const int lines, - uint64_t ub_timeout) { + int4 *srcptr, int4 *dstptr, const int lines, uint64_t ub_timeout) { #define UNROLLCOPY 8 const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; const int end_elem = lines; @@ -2004,8 +1919,10 @@ __global__ void __launch_bounds__(MAX_THREADS) clock_t s = clock64(); while (CHECK_IDS(*flag, signal_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { - UB_PRINT("pullrecv [grank dst:%d global src:%d][nvrank(GPU) dst: %d src: %d]: expecting %d," - " observed %d", myrank, peer, nvrank, nvpeer, signal_id, *flag); + UB_PRINT( + "pullrecv [grank dst:%d global src:%d][nvrank(GPU) dst: %d src: %d]: expecting %d," + " observed %d", + myrank, peer, nvrank, nvpeer, signal_id, *flag); break; } } @@ -2016,17 +1933,14 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (end_elem <= start_elem) - return; + if (end_elem <= start_elem) return; for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { int4 val[UNROLLCOPY]; #pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) - val[i] = srcptr[line + i * blockDim.x * gridDim.x]; + for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x * gridDim.x]; #pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) - dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x * gridDim.x] = val[i]; } for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) dstptr[line] = srcptr[line]; @@ -2044,18 +1958,15 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { int4 val[UNROLLCOPY]; #pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) - val[i] = srcptr[line + i * blockDim.x * gridDim.x]; + for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x * gridDim.x]; #pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) - dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x * gridDim.x] = val[i]; } for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) dstptr[line] = srcptr[line]; } __syncthreads(); - if (threadIdx.x) - return; + if (threadIdx.x) return; __threadfence_system(); atomicAdd_system(flagptr, 1); // otherwise need local SM sync before sending flag @@ -2064,8 +1975,8 @@ __global__ void __launch_bounds__(MAX_THREADS) } } -#define CHECK_CE(ce_start, ce_end) ((ce_start) != nullptr && (ce_end) != nullptr && \ - *(ce_start) != *(ce_end)) +#define CHECK_CE(ce_start, ce_end) \ + ((ce_start) != nullptr && (ce_end) != nullptr && *(ce_start) != *(ce_end)) __global__ void kuserbuffers_pushrecv(int myrank, int peer, int nvrank, int nvpeer, int *recv_id, int *flagptr, 
int adder, uint64_t ub_timeout, @@ -2073,16 +1984,17 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int nvrank, int nvpe const int signal_id = (*recv_id) + adder; *recv_id = signal_id; volatile int *flag = (volatile int *)flagptr; - if (*flag >= signal_id) - return; + if (*flag >= signal_id) return; clock_t s = clock64(); while (CHECK_IDS(*flag, signal_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { - UB_PRINT("pushrecv [grank dst:%d global src:%d][nvrank(GPU) dst: %d src: %d]: " - "expecting %d, observed %d", myrank, peer, nvrank, nvpeer, signal_id, *flag); + UB_PRINT( + "pushrecv [grank dst:%d global src:%d][nvrank(GPU) dst: %d src: %d]: " + "expecting %d, observed %d", + myrank, peer, nvrank, nvpeer, signal_id, *flag); if (CHECK_CE(ce_start_ptr, ce_end_ptr)) - UB_PRINT("pushrecv: CE deadlock DETECTED: %d (ce_start) != %d (ce_end)\n", - *ce_start_ptr, *ce_end_ptr); + UB_PRINT("pushrecv: CE deadlock DETECTED: %d (ce_start) != %d (ce_end)\n", *ce_start_ptr, + *ce_end_ptr); return; } } @@ -2091,8 +2003,8 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int nvrank, int nvpe __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv(int *send_id, int *send_flagptr, int4 *srcptr, int4 *dstptr, const int lines, int send_peer, int recv_peer, int *recv_id, - int *recv_flagptr, int adder, uint64_t ub_timeout, - int nv_send, int nv_recv, int *ce_start_ptr, int *ce_end_ptr) { + int *recv_flagptr, int adder, uint64_t ub_timeout, int nv_send, + int nv_recv, int *ce_start_ptr, int *ce_end_ptr) { if (lines) { const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; const int end_elem = lines; @@ -2116,8 +2028,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } } __syncthreads(); - if (threadIdx.x) - return; + if (threadIdx.x) return; __threadfence_system(); atomicAdd_system(send_flagptr, 1); // otherwise need local SM sync before sending flag @@ -2129,17 +2040,17 @@ __global__ void __launch_bounds__(MAX_THREADS) const int signal_id = (*recv_id) + adder; *recv_id = signal_id; volatile int *flag = (volatile int *)recv_flagptr; - if (*flag >= signal_id) - return; + if (*flag >= signal_id) return; clock_t s = clock64(); while (CHECK_IDS(*flag, signal_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { - UB_PRINT("pushsendrecv [sending peer:%d receiving peer:%d][nvrank(GPU) sending peer: %d" - " receiving peer: %d]: expecting %d, observed %d", - send_peer, recv_peer, nv_send, nv_recv, signal_id, *flag); + UB_PRINT( + "pushsendrecv [sending peer:%d receiving peer:%d][nvrank(GPU) sending peer: %d" + " receiving peer: %d]: expecting %d, observed %d", + send_peer, recv_peer, nv_send, nv_recv, signal_id, *flag); if (CHECK_CE(ce_start_ptr, ce_end_ptr)) - UB_PRINT("pushrecv: CE deadlock DETECTED: %d (ce_start) != %d (ce_end)\n", - *ce_start_ptr, *ce_end_ptr); + UB_PRINT("pushrecv: CE deadlock DETECTED: %d (ce_start) != %d (ce_end)\n", *ce_start_ptr, + *ce_end_ptr); return; } } @@ -2175,8 +2086,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } } __syncthreads(); - if (threadIdx.x) - return; + if (threadIdx.x) return; __threadfence_system(); atomicAdd_system(send_flagptr, 1); // otherwise need local SM sync before sending flag @@ -2191,9 +2101,10 @@ __global__ void __launch_bounds__(MAX_THREADS) clock_t s = clock64(); while (CHECK_IDS(*flag, signal_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { - UB_PRINT("pushsendrecv atomic [sending peer:%d receiving peer:%d][nvrank(GPU) sending peer:" - " %d receiving peer: %d]: expecting %d, observed %d", - send_peer, recv_peer, nv_send, 
nv_recv, signal_id, *flag); /*return;*/ + UB_PRINT( + "pushsendrecv atomic [sending peer:%d receiving peer:%d][nvrank(GPU) sending peer:" + " %d receiving peer: %d]: expecting %d, observed %d", + send_peer, recv_peer, nv_send, nv_recv, signal_id, *flag); /*return;*/ if (CHECK_CE(ce_start_ptr, ce_end_ptr)) UB_PRINT("pushsendrecv atomic: CE deadlock DETECTED: %d (ce_start) != %d (ce_end)\n", *ce_start_ptr, *ce_end_ptr); @@ -2208,13 +2119,10 @@ __global__ void __launch_bounds__(MAX_THREADS) } } -__global__ void __launch_bounds__(MAX_THREADS) - kuserbuffers_pushsendrecv_multiatomic(int *send_id, int *send_flagptr, int4 *srcptr, - int4 *dstptr, const int lines, int send_peer, - int recv_peer, int *recv_id, int *recv_flagptr, int adder, - void *counters, int nchunks, int send_stride, - int recv_stride, bool shuffle, - uint64_t ub_timeout, int nv_send, int nv_recv) { +__global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiatomic( + int *send_id, int *send_flagptr, int4 *srcptr, int4 *dstptr, const int lines, int send_peer, + int recv_peer, int *recv_id, int *recv_flagptr, int adder, void *counters, int nchunks, + int send_stride, int recv_stride, bool shuffle, uint64_t ub_timeout, int nv_send, int nv_recv) { for (int chunk_i = 0; chunk_i < nchunks - 1; chunk_i++) { int send_chunk_id = shuffle ? chunk_i : (nchunks + send_peer - chunk_i) % nchunks; int recv_chunk_id = shuffle ? chunk_i + 1 : (nchunks + send_peer - chunk_i - 1) % nchunks; @@ -2262,9 +2170,10 @@ __global__ void __launch_bounds__(MAX_THREADS) clock_t s = clock64(); while (CHECK_IDS(*flag, signal_id)) { if (CHECK_TIMEOUT(s, ub_timeout)) { - UB_PRINT("pushsendrecv multiatomic [sending peer:%d receiving peer:%d][nvrank(GPU)" - " sending peer: %d receiving peer: %d]: expecting %d, observed %d", - send_peer, recv_peer, nv_send, nv_recv, signal_id, *flag); /*return;*/ + UB_PRINT( + "pushsendrecv multiatomic [sending peer:%d receiving peer:%d][nvrank(GPU)" + " sending peer: %d receiving peer: %d]: expecting %d, observed %d", + send_peer, recv_peer, nv_send, nv_recv, signal_id, *flag); /*return;*/ // CE mode is not supported for multi-atomic, so there is no need to check for a deadlock return; } @@ -2290,13 +2199,13 @@ __global__ void __launch_bounds__(MAX_THREADS) } } -#define CUDACHECK(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ } while (0) // Return TRUE if two ranks share the same NV domain @@ -2306,37 +2215,34 @@ __global__ void __launch_bounds__(MAX_THREADS) // 0 - Send index counter // 1 - CE start index counter // 2 - CE end index counter -#define GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsth, index) \ - ((reinterpret_cast((comm)->peer_ptr[0][(peerlocal)])) + \ - ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + \ - (comm)->myrank * NVTE_MAX_REGIONS + (dsth) + \ - (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ +#define GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsth, index) \ + ((reinterpret_cast((comm)->peer_ptr[0][(peerlocal)])) + \ + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (comm)->myrank * NVTE_MAX_REGIONS + (dsth) + \ + (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ sizeof(int))) // Index corresponds to the type of flag: // 0 - Receive index counter // 1 - CE 
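// ---------------------------------------------------------------------------
// Minimal sketch (illustration only) of the push-mode handshake used by the
// kuserbuffers_push* kernels above: the producer copies the payload into the
// peer-mapped buffer, makes it visible with __threadfence_system(), then bumps
// the peer's flag with a system-scope atomic; the consumer spins on the
// volatile flag until it reaches the expected value or a clock64() timeout
// expires. This simplifies the CHECK_IDS / CHECK_TIMEOUT macros; names below
// are illustrative.
__device__ void push_signal_sketch(const int4 *src, int4 *dst_on_peer, int lines,
                                   int *flag_on_peer) {
  for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < lines;
       line += blockDim.x * gridDim.x)
    dst_on_peer[line] = src[line];
  __syncthreads();
  if (threadIdx.x) return;
  __threadfence_system();             // make the copied lines visible to the peer
  atomicAdd_system(flag_on_peer, 1);  // then publish the new signal id
}

__device__ void wait_signal_sketch(volatile int *flag, int expected,
                                   unsigned long long timeout_cycles) {
  const long long start = clock64();
  while (*flag < expected) {
    if (static_cast<unsigned long long>(clock64() - start) > timeout_cycles)
      break;  // give up instead of hanging the GPU
  }
}
// ---------------------------------------------------------------------------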
start index counter // 2 - CE end index counter -#define GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsth, index) \ - ((reinterpret_cast((comm)->mem_ptr[0])) + \ - ((NVTE_REG0_OFFSET(comm) + \ - NVTE_REG0_RECV + (recv_peer) * NVTE_MAX_REGIONS + \ - (dsth) + (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ +#define GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsth, index) \ + ((reinterpret_cast((comm)->mem_ptr[0])) + \ + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (recv_peer) * NVTE_MAX_REGIONS + (dsth) + \ + (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ sizeof(int))) void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, const int peer, cudaStream_t stream) { - int peerlocal = peer % comm->nvsize; - void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 0); + int peerlocal = peer % comm->nvsize; + void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 0); void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1); - void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 2); - bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); + void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 2); + bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); assert(INTRANODE(peer)); - if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) - return; + if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; if (comm->push == 0) { kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]), reinterpret_cast(flagptr)); @@ -2367,14 +2273,14 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); int send_peerlocal = send_peer % comm->nvsize; int recv_peerlocal = recv_peer % comm->nvsize; - void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); + void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1); - void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2); + void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2); void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0); void *send_srcptr = reinterpret_cast(comm->mem_ptr[srchandler]) + send_offset; - void *send_dstptr = reinterpret_cast(comm->peer_ptr[dsthandler][send_peerlocal]) - + send_offset; + void *send_dstptr = + reinterpret_cast(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset; if (comm->use_ce) { kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_start_ptr)); @@ -2396,17 +2302,15 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size uint64_t arg11 = comm->ub_timeout; int arg12 = send_peerlocal; int arg13 = recv_peerlocal; - int *arg14 = reinterpret_cast(comm->use_ce ? - GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1): - nullptr); - int *arg15 = reinterpret_cast(comm->use_ce ? - GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2): - nullptr); - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), - reinterpret_cast(&arg3), reinterpret_cast(&arg4), - reinterpret_cast(&arg5), reinterpret_cast(&arg6), - reinterpret_cast(&arg7), reinterpret_cast(&arg8), - reinterpret_cast(&arg9), reinterpret_cast(&arg10), + int *arg14 = reinterpret_cast( + comm->use_ce ? 
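// ---------------------------------------------------------------------------
// Minimal sketch (illustration only) of the flag-slot addressing behind
// GET_SEND_PTR_BY_INDEX / GET_RECV_PTR_BY_INDEX above: region 0 holds one int
// counter per (sender rank, destination region, counter kind), where kind 0 is
// the send/recv index and kinds 1/2 bracket a copy-engine transfer. The sender
// resolves the slot inside the peer's mapping of region 0, the receiver inside
// its own mapping. Here reg0_offset stands in for NVTE_REG0_OFFSET(comm) and
// reg0_recv for NVTE_REG0_RECV; the helper name is illustrative.
inline char *flag_slot_sketch(char *region0_base, int reg0_offset, int reg0_recv,
                              int sender_rank, int dst_region, int kind) {
  const int slot = reg0_offset + reg0_recv + sender_rank * NVTE_MAX_REGIONS + dst_region +
                   kind * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS;
  return region0_base + slot * sizeof(int);
}
// ---------------------------------------------------------------------------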
GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1) : nullptr); + int *arg15 = reinterpret_cast( + comm->use_ce ? GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2) : nullptr); + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), + reinterpret_cast(&arg3), reinterpret_cast(&arg4), + reinterpret_cast(&arg5), reinterpret_cast(&arg6), + reinterpret_cast(&arg7), reinterpret_cast(&arg8), + reinterpret_cast(&arg9), reinterpret_cast(&arg10), reinterpret_cast(&arg11), reinterpret_cast(&arg12), reinterpret_cast(&arg13), reinterpret_cast(&arg14), reinterpret_cast(&arg15)}; @@ -2423,14 +2327,14 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, int send_peerlocal = send_peer % comm->nvsize; int recv_peerlocal = recv_peer % comm->nvsize; - void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); + void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1); - void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2); - void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0); + void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2); + void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0); void *send_srcptr = reinterpret_cast(comm->mem_ptr[srchandler]) + send_offset; - void *send_dstptr = reinterpret_cast(comm->peer_ptr[dsthandler][send_peerlocal]) - + send_offset; + void *send_dstptr = + reinterpret_cast(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset; if (comm->use_ce) { kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_start_ptr)); CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); @@ -2452,17 +2356,15 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, int arg12 = comm->ub_timeout; int arg13 = send_peerlocal; int arg14 = recv_peerlocal; - int *arg15 = reinterpret_cast(comm->use_ce ? - GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1) : - nullptr); - int *arg16 = reinterpret_cast(comm->use_ce ? - GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2) : - nullptr); - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), - reinterpret_cast(&arg3), reinterpret_cast(&arg4), - reinterpret_cast(&arg5), reinterpret_cast(&arg6), - reinterpret_cast(&arg7), reinterpret_cast(&arg8), - reinterpret_cast(&arg9), reinterpret_cast(&arg10), + int *arg15 = reinterpret_cast( + comm->use_ce ? GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1) : nullptr); + int *arg16 = reinterpret_cast( + comm->use_ce ? 
GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2) : nullptr); + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), + reinterpret_cast(&arg3), reinterpret_cast(&arg4), + reinterpret_cast(&arg5), reinterpret_cast(&arg6), + reinterpret_cast(&arg7), reinterpret_cast(&arg8), + reinterpret_cast(&arg9), reinterpret_cast(&arg10), reinterpret_cast(&arg11), reinterpret_cast(&arg12), reinterpret_cast(&arg13), reinterpret_cast(&arg14), reinterpret_cast(&arg15), reinterpret_cast(&arg16)}; @@ -2513,30 +2415,28 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler reinterpret_cast(&arg15), reinterpret_cast(&arg16), reinterpret_cast(&arg17), reinterpret_cast(&arg18)}; CUDACHECK(cudaLaunchKernelExC( - &cfg, reinterpret_cast(kuserbuffers_pushsendrecv_multiatomic), kernelArgs)); + &cfg, reinterpret_cast(kuserbuffers_pushsendrecv_multiatomic), kernelArgs)); } void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, const int peer, cudaStream_t stream) { - int peerlocal = peer % comm->nvsize; - void *flagptr = GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 0); + int peerlocal = peer % comm->nvsize; + void *flagptr = GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 0); bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); assert(INTRANODE(peer)); - if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) - return; + if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; if (comm->push == 0) { void *dstptr = reinterpret_cast(comm->mem_ptr[dsthandler]) + dstoffset; void *srcptr = reinterpret_cast(comm->peer_ptr[srchandler][peerlocal]) + srcoffset; kuserbuffers_pullrecv<<sms, signalonly ? 1 : 1024, 0, stream>>>( - comm->myrank, peer, comm->nvrank, - peerlocal, &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), - reinterpret_cast(flagptr), reinterpret_cast(srcptr), - reinterpret_cast(dstptr), signalonly ? 0 : bytes / 16, - comm->ub_timeout); + comm->myrank, peer, comm->nvrank, peerlocal, + &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast(flagptr), + reinterpret_cast(srcptr), reinterpret_cast(dstptr), + signalonly ? 0 : bytes / 16, comm->ub_timeout); if (!signalonly) kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler])); if (comm->use_ce) { @@ -2545,13 +2445,12 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds } else { kuserbuffers_pushrecv<<<1, 1, 0, stream>>>( comm->myrank, peer, comm->nvrank, peerlocal, - &comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler], - reinterpret_cast(flagptr), signalonly || comm->sms, - comm->ub_timeout, - reinterpret_cast(comm->use_ce ? - GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 1) : nullptr), - reinterpret_cast(comm->use_ce ? - GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2) : nullptr)); + &comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler], reinterpret_cast(flagptr), + signalonly || comm->sms, comm->ub_timeout, + reinterpret_cast(comm->use_ce ? GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 1) + : nullptr), + reinterpret_cast(comm->use_ce ? 
GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2) + : nullptr)); } } @@ -2612,31 +2511,33 @@ void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStr template __global__ void __launch_bounds__(MAX_THREADS / 4) -reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale, - const int num_inputs, const int input_size) { + reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale, + const int num_inputs, const int input_size) { const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; fp8type *inputs_fp8 = reinterpret_cast(inputs); float accum_buf = static_cast(inputs_fp8[tid]) * (*scale); - #pragma unroll +#pragma unroll for (int i = 1; i < num_inputs; i++) { accum_buf += static_cast(inputs_fp8[tid + input_size * i]) * (*scale); } half *output_half = reinterpret_cast(output); - output_half[tid] = (half) accum_buf; + output_half[tid] = (half)accum_buf; } template void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream) { size_t num_threads = MAX_THREADS / 4; - size_t num_blocks = (input_size +num_threads - 1) / num_threads; + size_t num_blocks = (input_size + num_threads - 1) / num_threads; dim3 block(num_threads); dim3 grid(num_blocks); - reduce_fp8_in_bf16_out_cuda<<>>( - inputs, output, scale, num_inputs, input_size); + reduce_fp8_in_bf16_out_cuda + <<>>(inputs, output, scale, num_inputs, input_size); } -template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>( - void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream); -template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>( - void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream); +template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale, + int num_inputs, int input_size, + cudaStream_t stream); +template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, float *scale, + int num_inputs, int input_size, + cudaStream_t stream); diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h index 5c63351ee7472a00672269689cdb173fc6e14153..e8dbf978237ef1fb8bd8832d221ef009f99132ac 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h @@ -10,26 +10,28 @@ #include #include #include + #include -#include #include +#include #ifdef UB_MPI_BOOTSTRAP -#include #include -#define UB_MPI_CHECK(expr) \ - do { \ - const int mpicode = (expr); \ - if (mpicode != MPI_SUCCESS) { \ - char mpimsg[MPI_MAX_ERROR_STRING]; \ - int mpilen; \ - MPI_Error_string(mpicode, mpimsg, &mpilen); \ - std::vector errmsg(1024); \ - snprintf(errmsg.data(), errmsg.size(), "%s:%s in function %s: %s", \ - __FILE__, __LINE__, __func__, mpimsg); \ - throw std::runtime_error(errmsg.data()); \ - } \ +#include + +#define UB_MPI_CHECK(expr) \ + do { \ + const int mpicode = (expr); \ + if (mpicode != MPI_SUCCESS) { \ + char mpimsg[MPI_MAX_ERROR_STRING]; \ + int mpilen; \ + MPI_Error_string(mpicode, mpimsg, &mpilen); \ + std::vector errmsg(1024); \ + snprintf(errmsg.data(), errmsg.size(), "%s:%s in function %s: %s", __FILE__, __LINE__, \ + __func__, mpimsg); \ + throw std::runtime_error(errmsg.data()); \ + } \ } while (false) typedef MPI_Comm ExtComm; @@ -39,24 +41,15 @@ void ub_alloc_copy_allgather(void **globaldata, void *localdata, size_t localbyt UB_MPI_CHECK(MPI_Comm_rank(comm, &myrank)); 
UB_MPI_CHECK(MPI_Comm_size(comm, &nranks)); *globaldata = malloc(nranks * localbytes); - UB_MPI_CHECK(MPI_Allgather(localdata, - localbytes, - MPI_BYTE, - *globaldata, - nranks * localbytes, - MPI_BYTE, - comm)); + UB_MPI_CHECK(MPI_Allgather(localdata, localbytes, MPI_BYTE, *globaldata, nranks * localbytes, + MPI_BYTE, comm)); } -void ub_barrier(ExtComm comm) { - UB_MPI_CHECK(MPI_Barrier(comm)); -} +void ub_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); } -void ub_free(void *ptr) { - free(ptr); -} +void ub_free(void *ptr) { free(ptr); } #else -typedef char* ExtComm; +typedef char *ExtComm; #endif #define NVTE_MAX_REGIONS 16 @@ -80,8 +73,8 @@ typedef char* ExtComm; #define NVTE_REG0_OPFLAGS 1024 #define NVTE_REG0_RECV (NVTE_REG0_OPFLAGS * userbuffers_op_types) #define NVTE_REG0_SINGLENODE (2 * NVTE_MAX_NVLINK * NVTE_MAX_SMS + NVTE_MAX_OPS) -#define NVTE_REG0_OFFSET(comm) ((2 * NVTE_MAX_REGIONS) * NVTE_MAX_NVLINK \ - + NVTE_REG0_SINGLENODE * 2 + NVTE_MAX_PEERS) +#define NVTE_REG0_OFFSET(comm) \ + ((2 * NVTE_MAX_REGIONS) * NVTE_MAX_NVLINK + NVTE_REG0_SINGLENODE * 2 + NVTE_MAX_PEERS) #define NVTE_REG0_COMMBUFFER 0 // x3 for [flagptr, ce_start_ptr, ce_end_ptr] #define NVTE_REG0_FLAGS (NVTE_REG0_RECV + NVTE_MAX_PEERS * NVTE_MAX_REGIONS * 3) @@ -90,7 +83,7 @@ typedef char* ExtComm; #if defined(UCP) || !defined(NOSHARP) #undef REG0_COMMBUFFER -#define REG0_COMMBUFFER (1024*1024*16) +#define REG0_COMMBUFFER (1024 * 1024 * 16) #endif // gpuflags map offsets #define NVTE_GF_STATE 16000 @@ -142,12 +135,12 @@ struct communicator { int memflags[NVTE_MAX_REGIONS]; // UC,MC, user/lib allocated CUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS]; - void* ucbase_ptr[NVTE_MAX_REGIONS]; // only for cuMem allocated memory + void *ucbase_ptr[NVTE_MAX_REGIONS]; // only for cuMem allocated memory size_t mem_size[NVTE_MAX_REGIONS]; bool mem_dealloc[NVTE_MAX_REGIONS]; - void* mc_ptr[NVTE_MAX_REGIONS]; - void* mc_baseptr; + void *mc_ptr[NVTE_MAX_REGIONS]; + void *mc_baseptr; CUmemGenericAllocationHandle mc_handle; size_t mc_offset, mc_maxsize; int use_mc; // 1: use MC if available, 0: override not to use MC @@ -177,13 +170,13 @@ struct communicator { volatile int tail; // Abstract communication callbacks to support external bootstrapping (e.g. 
DL frameworks) - std::function _alloc_copy_allgather; + std::function _alloc_copy_allgather; std::function _barrier; - std::function _free; + std::function _free; ExtComm comm_world, - comm_inter, // reduction group communicator (subset of the nodes) along GPU rail - comm_intra; // full intranode (all ndev GPUS) + comm_inter, // reduction group communicator (subset of the nodes) along GPU rail + comm_intra; // full intranode (all ndev GPUS) #ifdef UB_MPI_BOOTSTRAP MPI_Request mpihndl[NVTE_MAX_SHARP]; #endif @@ -199,28 +192,25 @@ void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream); void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream); /* creates communicator, allocates all internal buffers if necessary */ -int create_communicator_grouped2(communicator **comm, - int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, - std::function ext_alloc_copy_allgather, - std::function ext_barrier, - std::function ext_free, - int pipegpus, int pipenodes, int tensorgpus, int tensornodes); - -int create_communicator_grouped(communicator **comm, - int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, - std::function ext_alloc_copy_allgather, - std::function ext_barrier, - std::function ext_free, - int pipegpus, int pipenodes); - -int create_communicator(communicator **comm, - int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, - std::function ext_alloc_copy_allgather, - std::function ext_barrier, - std::function ext_free); - -int create_communicator_grouped2_mpi(communicator **comm, - int pipegpus, int pipenodes, int tensorgpus, int tensornodes); +int create_communicator_grouped2( + communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, std::function ext_alloc_copy_allgather, + std::function ext_barrier, std::function ext_free, int pipegpus, + int pipenodes, int tensorgpus, int tensornodes); + +int create_communicator_grouped( + communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, std::function ext_alloc_copy_allgather, + std::function ext_barrier, std::function ext_free, int pipegpus, + int pipenodes); + +int create_communicator( + communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, std::function ext_alloc_copy_allgather, + std::function ext_barrier, std::function ext_free); + +int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes, + int tensorgpus, int tensornodes); int create_communicator_grouped_mpi(communicator **comm, int pipegpus, int pipenodes); @@ -273,37 +263,40 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons const int rowelements, const int colelements, const int strideelements, communicator *comm, cudaStream_t stream = 0); -template -void reducescatter2_userbuff_stridedoutput_fp8(void* output, float* scale, const int handler, +template +void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, - communicator* comm, cudaStream_t stream = 0); -template -void reducescatter2_userbuff_fp8(void* output, float* scale, const int handler, const int offset, - const int elements, communicator* comm, cudaStream_t stream = 0); -template -void reducescatter2_userbuff_strided_atomic_fp8(void* output, float *scale, const int handler, + communicator 
*comm, cudaStream_t stream = 0); +template +void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, + const int elements, communicator *comm, cudaStream_t stream = 0); +template +void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements_out, const int strideelements_in, const int numchunks, - void *counters, communicator* comm, + void *counters, communicator *comm, cudaStream_t stream = 0); -template +template void reducescatter2_userbuff_strided_multiatomic_fp8( - void* output, float *scale, const int handler, const int offset, const int rowelements, - const int colelements, const int strideelements_out, const int strideelements_in, - const int numchunks, void *counters, communicator* comm, cudaStream_t stream = 0); -void reducescatter2_userbuff_strided( - void* output, const int handler, const int offset, const int rowelements, const int colelements, - const int strideelements, communicator* comm, cudaStream_t stream = 0); -void reducescatter2_userbuff_strided_atomic( - void* output, const int handler , const int offset, const int rowelements, const int colelements, - const int strideelements, const int numchunks, void *counters, communicator* comm, - cudaStream_t stream = 0); -void reducescatter2_userbuff_strided_multiatomic( - void* output, const int handler, const int offset, const int rowelements, const int colelements, - const int strideelements, const int numchunks, void *counters, communicator* comm, - cudaStream_t stream = 0); + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator *comm, cudaStream_t stream = 0); +void reducescatter2_userbuff_strided(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, communicator *comm, + cudaStream_t stream = 0); +void reducescatter2_userbuff_strided_atomic(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, const int numchunks, + void *counters, communicator *comm, + cudaStream_t stream = 0); +void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, const int numchunks, + void *counters, communicator *comm, + cudaStream_t stream = 0); /* everything should be 16byte aligned = 8 elts aligned output is strided: row starts separated by stride elements*/ @@ -321,19 +314,18 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, const int peer, cudaStream_t stream = 0); -void userbuffers_sendrecv( - const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset, - const size_t bytes, communicator* comm, const int send_peer, const int recv_peer, - cudaStream_t stream = 0); -void userbuffers_sendrecv_atomic( - const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset, - const size_t bytes, communicator* comm, const int send_peer, const int recv_peer, void *counters, - cudaStream_t stream = 0); -void 
userbuffers_sendrecv_multiatomic( - const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset, - const size_t bytes, communicator* comm, const int send_peer, const int recv_peer, - const int nchunks, void *counters, bool shuffle, cudaStream_t stream = 0); - +void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size_t send_offset, + const size_t recv_offset, const size_t bytes, communicator *comm, + const int send_peer, const int recv_peer, cudaStream_t stream = 0); +void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, + const size_t send_offset, const size_t recv_offset, + const size_t bytes, communicator *comm, const int send_peer, + const int recv_peer, void *counters, cudaStream_t stream = 0); +void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler, + const size_t send_offset, const size_t recv_offset, + const size_t bytes, communicator *comm, const int send_peer, + const int recv_peer, const int nchunks, void *counters, + bool shuffle, cudaStream_t stream = 0); // alltoall split send and recv to allow for overlap // send kicks in sending data to the destination - invoke on same stream as data generation @@ -349,7 +341,7 @@ void userbuffers_alltoall_recv(communicator *comm, cudaStream_t stream = 0); void destroy_communicator(communicator *comm); template -void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs, - int input_size, cudaStream_t stream); +void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs, int input_size, + cudaStream_t stream); #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 0a44f52100dff789e1d4fa8ff0773fd0be3eb503..a2c5620f361316ce21aa46fcc398b3164fba14b7 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -51,10 +51,12 @@ def set_all_rng_states(states: List) -> None: def graph_safe_rng_available() -> bool: """Returns whether cuda graph safe RNG state manipulation is supported.""" - return (hasattr(torch.cuda.CUDAGraph, "register_generator_state") - and hasattr(torch.Generator, "graphsafe_set_state") - and hasattr(torch.Generator, "graphsafe_get_state") - and hasattr(torch.Generator, "clone_state")) + return ( + hasattr(torch.cuda.CUDAGraph, "register_generator_state") + and hasattr(torch.Generator, "graphsafe_set_state") + and hasattr(torch.Generator, "graphsafe_get_state") + and hasattr(torch.Generator, "clone_state") + ) def _get_cuda_rng_state( @@ -85,7 +87,7 @@ def _get_cuda_rng_state( def _set_cuda_rng_state( new_state: torch.Tensor, device: Union[int, str] = -1, - graph_safe = True, + graph_safe=True, ) -> None: """Sets the random number generator state of the current GPU.""" @@ -177,9 +179,7 @@ def split_tensor_into_1d_equal_chunks( return data -def gather_split_1d_tensor( - tensor: torch.Tensor, tp_group: dist_group_type -) -> torch.Tensor: +def gather_split_1d_tensor(tensor: torch.Tensor, tp_group: dist_group_type) -> torch.Tensor: """Opposite of above function, gather values from model parallel ranks.""" numel_gathered = torch.numel(tensor) * get_distributed_world_size(tp_group) gathered = torch.empty( @@ -200,11 +200,8 @@ class activation_recompute_forward(AbstractContextManager, ContextDecorator): retrieved, and the forward pass is computed again while tracking the intermediate activations, followed by calculation of gradients using these 
values. """ - def __init__( - self, - activation_recompute: bool = False, - recompute_phase: bool = False - ): + + def __init__(self, activation_recompute: bool = False, recompute_phase: bool = False): super().__init__() self.activation_recompute = activation_recompute self.recompute_phase = recompute_phase @@ -242,12 +239,14 @@ def _get_active_autocast_contexts(): gpu_autocast_enabled = torch.is_autocast_enabled() gpu_autocast_dtype = torch.get_autocast_gpu_dtype() gpu_autocast_ctx = torch.cuda.amp.autocast( - gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached) + gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached + ) cpu_autocast_enabled = torch.is_autocast_cpu_enabled() cpu_autocast_dtype = torch.get_autocast_cpu_dtype() cpu_autocast_ctx = torch.cpu.amp.autocast( - cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached) + cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached + ) return gpu_autocast_ctx, cpu_autocast_ctx @@ -291,9 +290,7 @@ class _CheckpointFunction(torch.autograd.Function): torch_gpu_amp_ctx, torch_cpu_amp_ctx = _get_active_autocast_contexts() with torch.no_grad(), forward_ctx: - with activation_recompute_forward( - activation_recompute=True, recompute_phase=False - ): + with activation_recompute_forward(activation_recompute=True, recompute_phase=False): outputs = run_function(*args, **kwargs) # Divide hidden states across model parallel group and only keep @@ -302,9 +299,7 @@ class _CheckpointFunction(torch.autograd.Function): ctx.input_0_shape = args[0].data.shape safely_set_viewless_tensor_data( args[0], - split_tensor_into_1d_equal_chunks( - args[0].data, tp_group, new_buffer=True - ), + split_tensor_into_1d_equal_chunks(args[0].data, tp_group, new_buffer=True), ) # Store everything. @@ -328,13 +323,11 @@ class _CheckpointFunction(torch.autograd.Function): """Call backward function with activation recomputation.""" if not torch.autograd._is_checkpoint_valid(): raise RuntimeError( - "Checkpointing is not compatible with .grad(), " - "please use .backward() if possible" + "Checkpointing is not compatible with .grad(), please use .backward() if possible" ) inputs = tuple( - t if t is not None else arg - for (t, arg) in zip(ctx.saved_tensors, ctx.inputs) + t if t is not None else arg for (t, arg) in zip(ctx.saved_tensors, ctx.inputs) ) get_rng_state_tracker = ctx.get_rng_state_tracker @@ -342,9 +335,7 @@ class _CheckpointFunction(torch.autograd.Function): if ctx.distribute_saved_activations: safely_set_viewless_tensor_data( inputs[0], - gather_split_1d_tensor(inputs[0].data, ctx.tp_group).view( - ctx.input_0_shape - ), + gather_split_1d_tensor(inputs[0].data, ctx.tp_group).view(ctx.input_0_shape), ) # Store the current states. @@ -361,10 +352,13 @@ class _CheckpointFunction(torch.autograd.Function): # Compute the forward pass. detached_inputs = detach_variable(inputs) - with (torch.enable_grad(), ctx.recompute_ctx, - ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, - activation_recompute_forward( - activation_recompute=True, recompute_phase=True)): + with ( + torch.enable_grad(), + ctx.recompute_ctx, + ctx.torch_gpu_amp_ctx, + ctx.torch_cpu_amp_ctx, + activation_recompute_forward(activation_recompute=True, recompute_phase=True), + ): outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) # Set the states back to what it was at the start of this function. 
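# Illustrative usage sketch (not part of the diff): with this refactor the
# non-tensor options of `checkpoint` are keyword-only, matching the
# DeprecationWarning emitted below for the old positional form. The wrapped
# module is a plain placeholder; only the keyword-argument pattern is the point.
import torch
from transformer_engine.pytorch.distributed import checkpoint

mlp = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.GELU(), torch.nn.Linear(64, 16)).cuda()
x = torch.randn(4, 16, device="cuda", requires_grad=True)
out = checkpoint(
    mlp,
    x,
    distribute_saved_activations=False,
    tp_group=None,
    get_rng_state_tracker=None,
)
out.sum().backward()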
@@ -384,14 +378,12 @@ class _CheckpointFunction(torch.autograd.Function): args_with_grad.append(args[i]) if len(outputs_with_grad) == 0: raise RuntimeError( - "none of output has requires_grad=True," - " this checkpoint() is not necessary" + "none of output has requires_grad=True, this checkpoint() is not necessary" ) torch.autograd.backward(outputs_with_grad, args_with_grad) grads = tuple( - inp.grad if isinstance(inp, torch.Tensor) else None - for inp in detached_inputs + inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs ) return (None, None, None, None, None, None) + grads @@ -400,11 +392,8 @@ class _CheckpointFrame: """ Storage frame for forward RNG states and detached activations from the forward recompute. """ - def __init__( - self, - recompute_fn: Callable, - get_rng_state_tracker: Callable - ): + + def __init__(self, recompute_fn: Callable, get_rng_state_tracker: Callable): self.recompute_fn = recompute_fn self.recomputed = [] self.count = 0 @@ -412,7 +401,6 @@ class _CheckpointFrame: self.fwd_rng_states = None self.bwd_rng_states = None - def cache_rng_states(self, forward=True): """Cache fwd/bwd RNG states in the frame to restore later.""" rng_states = ( @@ -420,7 +408,7 @@ class _CheckpointFrame: _get_cuda_rng_state(graph_safe=False), ) if self.get_rng_state_tracker is not None: - rng_states += (self.get_rng_state_tracker().get_states(), ) + rng_states += (self.get_rng_state_tracker().get_states(),) if forward: self.fwd_rng_states = rng_states @@ -440,7 +428,9 @@ class _CheckpointFrame: self.get_rng_state_tracker().set_states(rng_states[2]) -class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks): # pylint: disable=too-few-public-methods +class _recomputation_hook( + torch.autograd.graph.saved_tensors_hooks +): # pylint: disable=too-few-public-methods """torch.autograd hook for packing/unpacking tensors during the activation recompute phase.""" def __init__(self, frame): @@ -463,7 +453,9 @@ class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks): # pylint: super().__init__(pack_hook, unpack_hook) -class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks): # pylint: disable=too-few-public-methods +class _checkpoint_hook( + torch.autograd.graph.saved_tensors_hooks +): # pylint: disable=too-few-public-methods """torch.autograd hook for packing/unpacking tensors during the checkpointed forward pass.""" def __init__(self, frame, args, kwargs): @@ -557,6 +549,7 @@ def has_te_modules(network): # so just assume that it has TE modules just to be safe. return True + @torch._disable_dynamo def checkpoint( function: Callable, @@ -607,13 +600,18 @@ def checkpoint( get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None) # Ensure backward compatibility. - if (len(args) > 3 and isinstance(args[0], bool) and callable(args[1]) - and isinstance(args[2], None | dist_group_type)): + if ( + len(args) > 3 + and isinstance(args[0], bool) + and callable(args[1]) + and isinstance(args[2], None | dist_group_type) + ): warnings.warn( "Passing non-tensor non-keyword arguments is deprecated and support will be removed in " "future releases of TransformerEngine. 
`distribute_saved_activations`, `tp_group`, and " "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.", - DeprecationWarning, stacklevel=2, + DeprecationWarning, + stacklevel=2, ) distribute_saved_activations = args[0] get_rng_state_tracker = args[1] @@ -633,7 +631,7 @@ def checkpoint( context_fn=context_fn, determinism_check=determinism_check, debug=debug, - **kwargs + **kwargs, ) # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need @@ -680,9 +678,13 @@ def checkpoint( torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts() def recompute_fn(*args, **kwargs): - with (torch.autograd.enable_grad(), - te_recompute_ctx, user_recompute_ctx, - torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx): + with ( + torch.autograd.enable_grad(), + te_recompute_ctx, + user_recompute_ctx, + torch_gpu_amp_forward_ctx, + torch_cpu_amp_forward_ctx, + ): function(*args, **kwargs) # Initialize a new checkpoint frame for each new forward pass. @@ -692,8 +694,7 @@ def checkpoint( ) new_frame.cache_rng_states(forward=True) - with (_checkpoint_hook(new_frame, args, kwargs), - te_forward_ctx, user_forward_ctx): + with _checkpoint_hook(new_frame, args, kwargs), te_forward_ctx, user_forward_ctx: out = function(*args, **kwargs) return out @@ -820,9 +821,7 @@ def reduce_scatter_along_first_dim( dim_size[0] = dim_size[0] // world_size - output = torch.empty( - dim_size, dtype=input_.dtype, device=torch.cuda.current_device() - ) + output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) handle = torch.distributed.reduce_scatter_tensor( output, input_.contiguous(), group=tp_group, async_op=async_op ) @@ -842,9 +841,7 @@ def gather_along_first_dim( dim_size = list(input_.size()) dim_size[0] = dim_size[0] * world_size - output = torch.empty( - dim_size, dtype=input_.dtype, device=torch.cuda.current_device() - ) + output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) handle = torch.distributed.all_gather_into_tensor( output, input_.contiguous(), group=tp_group, async_op=async_op ) @@ -880,8 +877,8 @@ def _fsdp_scatter_tensors( target = t._data if isinstance(t, Float8Tensor) else t shapes.append(target.data.shape) safely_set_viewless_tensor_data( - target, split_tensor_into_1d_equal_chunks( - target.data, fsdp_group, new_buffer=True) + target, + split_tensor_into_1d_equal_chunks(target.data, fsdp_group, new_buffer=True), ) else: shapes.append(None) @@ -890,7 +887,7 @@ def _fsdp_scatter_tensors( def _fsdp_gather_tensors( fsdp_group: dist_group_type, - shapes: List[Tuple[int,...]], + shapes: List[Tuple[int, ...]], *tensors: torch.Tensor, ): if fsdp_group is not None: @@ -913,6 +910,7 @@ def _is_te_module(module): from .module.base import TransformerEngineBaseModule from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention from .transformer import TransformerLayer + te_classes_list = [ LayerNorm, RMSNorm, diff --git a/transformer_engine/pytorch/export.py b/transformer_engine/pytorch/export.py index 2769c752cfad013579a756d4abf04105d758be01..5bc079711a9ef4a779d38e280c382e6da2964735 100755 --- a/transformer_engine/pytorch/export.py +++ b/transformer_engine/pytorch/export.py @@ -7,6 +7,7 @@ from contextlib import contextmanager _IN_ONNX_EXPORT_MODE = False + @contextmanager def onnx_export( enabled: bool = False, @@ -26,13 +27,14 @@ def onnx_export( """ global _IN_ONNX_EXPORT_MODE - onnx_export_state = (_IN_ONNX_EXPORT_MODE) + 
onnx_export_state = _IN_ONNX_EXPORT_MODE try: _IN_ONNX_EXPORT_MODE = enabled yield finally: _IN_ONNX_EXPORT_MODE = onnx_export_state + def is_in_onnx_export_mode() -> bool: """Returns True if onnx export mode is enabled, False otherwise.""" return _IN_ONNX_EXPORT_MODE diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index e486a84401790d4a0ae8746dd65fa4cfd92435a6..a38c88cf31b563a75d10d7db8b128fe7c321c1dc 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -34,17 +34,22 @@ def _make_fp8_attr_property_funcs(name: str) -> Any: Key in dictionary of FP8 attributes """ + def get_func(self) -> Any: return self._fp8_attrs[name] + def set_func(self, value: Any) -> None: self._fp8_attrs[name] = value + def del_func(self) -> None: del self._fp8_attrs[name] + return dict(fget=get_func, fset=set_func, fdel=del_func) class _FromFloat8Func(torch.autograd.Function): """Cast from FP8 to other dtype""" + @staticmethod def forward( _ctx: torch.autograd.function.FunctionCtx, # unused @@ -53,7 +58,7 @@ class _FromFloat8Func(torch.autograd.Function): ) -> torch.Tensor: if dtype is None: dtype = tensor.dtype - data = tensor._data.contiguous().view(1,-1).detach() + data = tensor._data.contiguous().view(1, -1).detach() out = tex.cast_from_fp8( data, tensor._scale_inv, @@ -92,13 +97,13 @@ def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] # All FP8 trainable parameters have been updated. if updated_fp8_params[autocast_key] == current_fp8_params_set: - FP8GlobalStateManager.reduce_and_update_fp8_tensors( - forward=True, fp8_weights=True) + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True) del updated_fp8_params[autocast_key] class _ToFloat8Func(torch.autograd.Function): """Cast to FP8 from other dtype""" + @staticmethod def forward( _ctx: torch.autograd.function.FunctionCtx, # unused @@ -153,9 +158,7 @@ class _ToFloat8Func(torch.autograd.Function): device=tensor.device, ) if scale.numel() != 1: - raise ValueError( - "Attempted to initialize Float8Tensor with invalid scale tensor" - ) + raise ValueError("Attempted to initialize Float8Tensor with invalid scale tensor") scale = scale.to(device=tensor.device, dtype=torch.float32) # Check scale-inverse @@ -167,13 +170,11 @@ class _ToFloat8Func(torch.autograd.Function): if amax is None: amax = torch.empty_like(scale) if not (amax.numel() == 1 and amax.is_cuda and amax.dtype == torch.float32): - raise ValueError( - "Attempted to initialize Float8Tensor with invalid amax tensor" - ) + raise ValueError("Attempted to initialize Float8Tensor with invalid amax tensor") # Cast data to FP8 data = tex.cast_to_fp8( - tensor.view(1,-1), + tensor.view(1, -1), scale, amax, scale_inv, @@ -240,6 +241,7 @@ class _IdentityFunc(torch.autograd.Function): def backward(ctx, grad): return grad.to(ctx.input_dtype), None + class _ViewFunc(torch.autograd.Function): """View function @@ -268,7 +270,8 @@ class _ViewFunc(torch.autograd.Function): return tensor.view(*shape) @staticmethod - def backward(ctx, + def backward( + ctx, grad: torch.Tensor, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -309,7 +312,8 @@ class _ReshapeFunc(torch.autograd.Function): return tensor.reshape(*shape) @staticmethod - def backward(ctx, + def backward( + ctx, grad: torch.Tensor, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -375,13 +379,10 @@ class 
Float8Tensor(torch.Tensor): # Check that data buffer is valid if data.element_size() != 1: raise ValueError( - "Float8Tensor requires data buffer with 8-bit dtype " - f"(got dtype={data.dtype})" + f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" ) if data.requires_grad: - raise ValueError( - "Float8Tensor requires non-differentiable data buffer" - ) + raise ValueError("Float8Tensor requires non-differentiable data buffer") if not data.is_cuda: data = data.cuda() @@ -418,8 +419,9 @@ class Float8Tensor(torch.Tensor): self._fp8_meta_index: Optional[int] = fp8_meta_index # FP8 dtype - assert ( - fp8_dtype in (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2) + assert fp8_dtype in ( + tex.DType.kFloat8E4M3, + tex.DType.kFloat8E5M2, ), f"Unsupported fp8_dtype {fp8_dtype}." self._fp8_dtype: tex.DType = fp8_dtype @@ -451,10 +453,7 @@ class Float8Tensor(torch.Tensor): ) if fp8_scale_inv.dim() != 1: fp8_scale_inv = fp8_scale_inv.reshape(1) - if ( - fp8_scale_inv.device != self._data.device - or fp8_scale_inv.dtype != torch.float32 - ): + if fp8_scale_inv.device != self._data.device or fp8_scale_inv.dtype != torch.float32: fp8_scale_inv = fp8_scale_inv.to( device=self._data.device, dtype=torch.float32, @@ -674,7 +673,6 @@ class Float8Tensor(torch.Tensor): tensor.device != data.device or tensor.dtype not in (torch.float32, torch.float16, torch.bfloat16) or not tensor.is_contiguous() - ): dtype = tensor.dtype if dtype not in (torch.float32, torch.float16, torch.bfloat16): @@ -691,13 +689,11 @@ class Float8Tensor(torch.Tensor): ) if not data.is_contiguous(): raise ValueError( - "FP8 cast-transpose is only supported for " - "`Float8Tensor`s with contiguous data" + "FP8 cast-transpose is only supported for `Float8Tensor`s with contiguous data" ) if self._fp8_meta is None: raise ValueError( - "FP8 cast-transpose is only supported for " - "`Float8Tensor`s with FP8 metadata " + "FP8 cast-transpose is only supported for `Float8Tensor`s with FP8 metadata " ) # Construct transpose cache if needed @@ -726,7 +722,7 @@ class Float8Tensor(torch.Tensor): transpose_out=transpose, noop_flag=noop_flag, ) - scale = fp8_meta.scale[fp8_meta_index:fp8_meta_index+1] + scale = fp8_meta.scale[fp8_meta_index : fp8_meta_index + 1] scale_inv = self._scale_inv if noop_flag is None: torch.reciprocal(scale, out=scale_inv) @@ -784,13 +780,9 @@ class Float8Tensor(torch.Tensor): dst = args[0] src = args[1] if not isinstance(dst, torch.Tensor): - raise RuntimeError( - "Attempted to copy into something that isn't a PyTorch tensor" - ) + raise RuntimeError("Attempted to copy into something that isn't a PyTorch tensor") if not isinstance(src, torch.Tensor): - raise RuntimeError( - "Attempted to copy from something that isn't a PyTorch tensor" - ) + raise RuntimeError("Attempted to copy from something that isn't a PyTorch tensor") # Special handling based on which tensors are FP8 dst_is_fp8 = isinstance(dst, Float8Tensor) @@ -850,9 +842,9 @@ class Float8Tensor(torch.Tensor): if not dst._data.is_contiguous(): raise RuntimeError("Transformer Engine cast kernels require contiguous data") tex.cast_to_fp8_noalloc( - src.view(1,-1), + src.view(1, -1), scale, - dst._data.view(1,-1), + dst._data.view(1, -1), amax, dst._scale_inv, dst._fp8_dtype, @@ -904,12 +896,12 @@ class Float8Tensor(torch.Tensor): Keep the same FP8 scaling factors. 
""" - if( - isinstance(arg, Float8Tensor) and - isinstance(new_arg, torch.Tensor) and - hasattr(schema_arg, 'alias_info') and - hasattr(schema_arg.alias_info, 'is_write') and - schema_arg.alias_info.is_write + if ( + isinstance(arg, Float8Tensor) + and isinstance(new_arg, torch.Tensor) + and hasattr(schema_arg, "alias_info") + and hasattr(schema_arg.alias_info, "is_write") + and schema_arg.alias_info.is_write ): arg.copy_(new_arg) arg._reset_caches() diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index bf8b0f007535c4cc3b5c7c5feae7995765e53ddb..e15268b998c2468a1980a683a8e5435c2297c688 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -22,9 +22,9 @@ __all__ = ["fp8_autocast", "fp8_model_init"] def check_fp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" - if get_device_compute_capability() >= (9, 0): # hopper and above + if get_device_compute_capability() >= (9, 0): # hopper and above return True, "" - if get_device_compute_capability() < (8, 9): # pre-ada + if get_device_compute_capability() < (8, 9): # pre-ada return False, "Device compute capability 8.9 or higher required for FP8 execution." if tex.get_cublasLt_version() < 120103: return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." @@ -38,9 +38,7 @@ def get_default_fp8_recipe() -> DelayedScaling: return DelayedScaling() -def get_fp8_te_dtype( - fp8_recipe: DelayedScaling, fprop_tensor: bool = True -) -> tex.DType: +def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor @@ -49,9 +47,7 @@ def get_fp8_te_dtype( return tex.DType.kFloat8E5M2 -def get_fp8_max( - fp8_recipe: DelayedScaling, fprop_tensor: bool = True -) -> tex.DType: +def get_fp8_max(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor @@ -64,6 +60,7 @@ class FP8GlobalStateManager: """Class to keep track of and manipulate the global FP8 state at different stages of execution. """ + FP8_ENABLED = False FP8_CALIBRATION = False FP8_RECIPE = None @@ -207,19 +204,22 @@ class FP8GlobalStateManager: if forward and fp8_weights is not None: autocast_key = cls.get_unique_autocast_key( - fp8_meta["recipe"], fp8_meta["fp8_group"]) + fp8_meta["recipe"], fp8_meta["fp8_group"] + ) fp8_weight_set = {id(w._data) for w in fp8_weights} if autocast_key not in cls.autocast_to_fp8_params: cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set else: - cls.autocast_to_fp8_params[autocast_key] = ( - cls.autocast_to_fp8_params[autocast_key].union(fp8_weight_set)) + cls.autocast_to_fp8_params[autocast_key] = cls.autocast_to_fp8_params[ + autocast_key + ].union(fp8_weight_set) # Identify correct autocast key for a given param. 
for w in fp8_weight_set: cls.fp8_param_to_autocast[w] = autocast_key key = cls.get_key_in_buffer( - forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"]) + forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"] + ) if key not in cls.global_amax_buffer: cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] @@ -229,7 +229,8 @@ class FP8GlobalStateManager: else: cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) cls.global_amax_history_buffer[key].append( - fp8_meta[fp8_meta_tensor_key].amax_history) + fp8_meta[fp8_meta_tensor_key].amax_history + ) cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) @@ -283,25 +284,25 @@ class FP8GlobalStateManager: cls.FP8_RECIPE, cls.FP8_DISTRIBUTED_GROUP, cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING) + cls.FP8_GRAPH_CAPTURING, + ) @classmethod def set_fp8_autocast_state( - cls, - fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool] + cls, fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool] ) -> None: """FP8 autocast state setter""" - (cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING) = fp8_state + ( + cls.FP8_ENABLED, + cls.FP8_CALIBRATION, + cls.FP8_RECIPE, + cls.FP8_DISTRIBUTED_GROUP, + cls.IS_FIRST_FP8_MODULE, + cls.FP8_GRAPH_CAPTURING, + ) = fp8_state @staticmethod - def reduce_tensor_across_group_op_max( - tensor: torch.Tensor, group: dist_group_type - ) -> None: + def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None: """Reduce tensor across given group.""" if torch.distributed.is_initialized(): torch.distributed.all_reduce( @@ -337,15 +338,19 @@ class FP8GlobalStateManager: contiguous_amax = torch.cat(amax_buffer) # Reduction. - if (recipe.reduce_amax + if ( + recipe.reduce_amax and torch.distributed.is_initialized() - and torch.distributed.get_world_size(group=group) > 1): + and torch.distributed.get_world_size(group=group) > 1 + ): cls.reduce_tensor_across_group_op_max(contiguous_amax, group) # Amax and scale update. - unfused_update = (bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0"))) - or callable(recipe.amax_compute_algo) - or callable(recipe.scaling_factor_compute_algo)) + unfused_update = ( + bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0"))) + or callable(recipe.amax_compute_algo) + or callable(recipe.scaling_factor_compute_algo) + ) if not unfused_update: tex.fused_amax_and_scale_update_after_reduction( @@ -366,7 +371,8 @@ class FP8GlobalStateManager: cls.global_scale_inv_buffer[buffer_key], ): _amax_and_scale_update( - amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe) + amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe + ) @classmethod def get_unique_autocast_key( @@ -455,9 +461,7 @@ class FP8GlobalStateManager: # Retrieve stashed amaxes and scales from phase 1 pre forward. 
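# Minimal sketch of the amax bucketing pattern used above (names such as
# `amax_buckets`/`reduce_bucket` are illustrative, not TE API): per-module amax
# rows are grouped by their (forward, recipe, group) buffer key, concatenated
# once, and reduced with a single MAX all-reduce across the fp8 group.
import torch
import torch.distributed as dist

amax_buckets = {}  # key -> list of 1D amax tensors, one per TE module

def register_amax(key, amax_row):
    amax_buckets.setdefault(key, []).append(amax_row)

def reduce_bucket(key, group=None):
    contiguous_amax = torch.cat(amax_buckets[key])
    if dist.is_initialized() and dist.get_world_size(group=group) > 1:
        dist.all_reduce(contiguous_amax, op=dist.ReduceOp.MAX, group=group)
    return contiguous_amax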
buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[ - fp8_meta[buffer_position_key] - ].popleft() + stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() # Replace amaxes and scales with stashed values for phase 2 forward fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0] @@ -554,11 +558,13 @@ def fp8_autocast( are reduced at the end of each training step. """ fp8_state = FP8GlobalStateManager.get_fp8_autocast_state() - FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled, - calibrating=calibrating, - fp8_recipe=fp8_recipe, - fp8_group=fp8_group, - _graph=_graph) + FP8GlobalStateManager.fp8_autocast_enter( + enabled=enabled, + calibrating=calibrating, + fp8_recipe=fp8_recipe, + fp8_group=fp8_group, + _graph=_graph, + ) try: yield finally: @@ -610,7 +616,7 @@ def _default_sf_compute( 4. When amax == inf or amax == nan: No action is possible, set scale to the previous scale (or 1). """ - sf = (fp8_max / amax) / (2 ** margin) + sf = (fp8_max / amax) / (2**margin) sf = torch.where(amax > 0.0, sf, scale) sf = torch.where(torch.isfinite(amax), sf, scale) sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 3f73b5306df6faad598bbdf030e122a6c45f678b..a6f62ac4577895ca2c7b5f893a981b168a21ba42 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -78,19 +78,16 @@ def _make_graphed_callables( num_model_chunks = max(_order) num_microbatches = len(_order) // num_model_chunks // 2 assert num_model_chunks * num_microbatches * 2 == len(_order) - assert ( - len(sample_args)*2 >= len(_order) - and (len(sample_args)*2 % len(_order) == 0) - ), f'{len(sample_args)} >= {len(_order)} and {len(sample_args)} % {len(_order)} == 0' + assert len(sample_args) * 2 >= len(_order) and ( + len(sample_args) * 2 % len(_order) == 0 + ), f"{len(sample_args)} >= {len(_order)} and {len(sample_args)} % {len(_order)} == 0" num_layers = len(sample_args) // num_model_chunks // num_microbatches - assert ( - len(callables) == num_model_chunks*num_layers - ), (f"Callables should have ({num_model_chunks * num_layers}) " + assert len(callables) == num_model_chunks * num_layers, ( + f"Callables should have ({num_model_chunks * num_layers}) " + f"entries when order input is provided but got {len(callables)}." ) - assert ( - len(sample_args) == num_model_chunks * num_microbatches * num_layers - ), (f"Expected {num_model_chunks * num_microbatches}" + assert len(sample_args) == num_model_chunks * num_microbatches * num_layers, ( + f"Expected {num_model_chunks * num_microbatches}" + f"args tuple, but got {len(sample_args)}." 
) @@ -126,12 +123,10 @@ def _make_graphed_callables( per_callable_len_user_args = [len(args) for args in flatten_sample_args] if _order is None: per_callable_module_params = [ - tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () - for c in callables + tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () for c in callables ] per_callable_static_input_surfaces = [ - flatten_sample_args[i] + per_callable_module_params[i] - for i in range(len(callables)) + flatten_sample_args[i] + per_callable_module_params[i] for i in range(len(callables)) ] else: per_callable_module_params = [] @@ -171,9 +166,7 @@ def _make_graphed_callables( grad_inputs = torch.autograd.grad( outputs=tuple(o for o in outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), - grad_outputs=tuple( - torch.empty_like(o) for o in outputs if o.requires_grad - ), + grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad), only_inputs=True, allow_unused=allow_unused_input, ) @@ -184,7 +177,7 @@ def _make_graphed_callables( # the safest approach is to capture all passes in the same order they'll run: # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1. - if _order is not None: # pylint: disable=too-many-nested-blocks + if _order is not None: # pylint: disable=too-many-nested-blocks per_callable_static_outputs = [None] * len(flatten_sample_args) per_callable_output_unflatten_spec = [None] * len(flatten_sample_args) per_callable_static_grad_outputs = [None] * len(flatten_sample_args) @@ -194,11 +187,12 @@ def _make_graphed_callables( for c_id in _order: if c_id > 0: # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] - m_chunk = c_id-1 + m_chunk = c_id - 1 for l_no in range(num_layers): - func = callables[m_chunk*num_layers + l_no] - per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) \ - + (fwd_idx[m_chunk] * num_layers + l_no) + func = callables[m_chunk * num_layers + l_no] + per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + ( + fwd_idx[m_chunk] * num_layers + l_no + ) args = sample_args[per_callable_fwd_idx] fwd_graph = fwd_graphs[per_callable_fwd_idx] with torch.cuda.graph(fwd_graph, pool=mempool): @@ -210,10 +204,11 @@ def _make_graphed_callables( fwd_idx[m_chunk] += 1 else: # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1] - m_chunk = -c_id-1 + m_chunk = -c_id - 1 for l_no in list(reversed(range(num_layers))): - per_callable_bwd_idx = (m_chunk * num_microbatches * num_layers) \ - + (bwd_idx[m_chunk] * num_layers + l_no) + per_callable_bwd_idx = (m_chunk * num_microbatches * num_layers) + ( + bwd_idx[m_chunk] * num_layers + l_no + ) static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx] static_outputs = per_callable_static_outputs[per_callable_bwd_idx] bwd_graph = bwd_graphs[per_callable_bwd_idx] @@ -314,6 +309,7 @@ def _make_graphed_callables( ): class Graphed(torch.autograd.Function): """Autograd function for graph replay.""" + @staticmethod def forward(ctx, skip_fp8_weight_update, *inputs): # At this stage, only the user args may (potentially) be new tensors. @@ -356,9 +352,8 @@ def _make_graphed_callables( # Assumes module params didn't change since capture. 
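# Illustrative call pattern (assumptions: `make_graphed_callables` is re-exported
# from transformer_engine.pytorch, an FP8-capable GPU is available, and the layer
# sizes/loop bounds are placeholders). With `fp8_weight_caching=True`, every
# replay must pass the boolean `is_first_microbatch` kwarg asserted above so
# cached FP8 weight casts are reused on all but the first microbatch of a step.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(8, 1024, device="cuda", requires_grad=True)
graphed_layer = te.make_graphed_callables(
    layer, (sample_input,), fp8_enabled=True, fp8_weight_caching=True
)
for i in range(4):  # microbatches within one optimizer step
    out = graphed_layer(torch.randn(8, 1024, device="cuda"), is_first_microbatch=(i == 0))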
skip_fp8_weight_update = None if fp8_weight_caching: - assert ( - ("is_first_microbatch" in user_kwargs - and isinstance(user_kwargs["is_first_microbatch"], bool)) + assert "is_first_microbatch" in user_kwargs and isinstance( + user_kwargs["is_first_microbatch"], bool ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." skip_fp8_weight_update = not user_kwargs["is_first_microbatch"] @@ -394,14 +389,18 @@ def _make_graphed_callables( if func.training == graph_training_state: # Set the FP8 group from global amax reduction. for m in func.modules(): - if (isinstance(m, TransformerEngineBaseModule) - and FP8GlobalStateManager.is_fp8_enabled()): + if ( + isinstance(m, TransformerEngineBaseModule) + and FP8GlobalStateManager.is_fp8_enabled() + ): m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - m.fp8_meta, fp8_weights=m._get_fp8_params()) + m.fp8_meta, fp8_weights=m._get_fp8_params() + ) return graphed(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs) + return new_fwd forward = make_graphed_forward(func, func.training, graphed, func.forward) @@ -496,13 +495,14 @@ def make_graphed_callables( # FP8 wrapper. def wrap_autocast(block): old_forward = block.forward + def forward_func(*args, **kwargs): - with fp8_autocast(enabled=fp8_enabled, - calibrating=fp8_calibrating, - fp8_recipe=fp8_recipe, - _graph=True): + with fp8_autocast( + enabled=fp8_enabled, calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, _graph=True + ): outputs = old_forward(*args, **kwargs) return outputs + block.forward = forward_func forward_funcs = [] @@ -518,16 +518,22 @@ def make_graphed_callables( # Save RNG state. if graph_safe_rng_available(): - generators = [torch.cuda.default_generators[torch.cuda.current_device()], - *get_all_rng_states().values()] + generators = [ + torch.cuda.default_generators[torch.cuda.current_device()], + *get_all_rng_states().values(), + ] original_rng_states = [state.get_state() for state in generators] else: original_rng_states = torch.cuda.get_rng_state() graphed_callables = _make_graphed_callables( - forward_funcs, sample_args, num_warmup_iters=num_warmup_iters, + forward_funcs, + sample_args, + num_warmup_iters=num_warmup_iters, allow_unused_input=allow_unused_input, - fp8_weight_caching=fp8_weight_caching, _order=_order) + fp8_weight_caching=fp8_weight_caching, + _order=_order, + ) # Ensures warmup does not affect numerics for ops such as dropout. 
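# Sketch of the warmup RNG protection used here (illustrative only; `warmup_fn`
# is a placeholder). The generator state is captured before the warmup replays
# and restored afterwards so ops such as dropout see identical randomness in
# the subsequent captured graphs.
import torch

def run_warmup_preserving_rng(warmup_fn, num_warmup_iters=3):
    generator = torch.cuda.default_generators[torch.cuda.current_device()]
    saved_state = generator.get_state()
    for _ in range(num_warmup_iters):
        warmup_fn()
    generator.set_state(saved_state)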
if graph_safe_rng_available(): diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index e86fbf3b9a3afce784d5b99dd893716515b1be83..16468471628991d55cd218d783cdc504e112c215 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -22,9 +22,11 @@ if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1")) no_torch_dynamo = lambda recursive=True: lambda func: func if torch.__version__ >= "2": import torch._dynamo + if torch.__version__ >= "2.1": - no_torch_dynamo = lambda recursive=True: lambda f: \ - torch._dynamo.disable(f, recursive=recursive) + no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable( + f, recursive=recursive + ) else: # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True no_torch_dynamo = lambda recursive=True: torch._dynamo.disable @@ -35,7 +37,7 @@ def set_jit_fusion_options() -> None: # flags required to enable jit fusion kernels TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) - if (TORCH_MAJOR == 2 and TORCH_MINOR >= 2): + if TORCH_MAJOR == 2 and TORCH_MINOR >= 2: pass elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): # nvfuser @@ -81,27 +83,25 @@ def bgrad_dgelu_fused_( x = inp + bias tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ( - (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) - ) + 0.5 * (1 + tanh_out) + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( + 1 + tanh_out + ) dgelu = ff * grad_output bgrad = dgelu.sum(dim=0) return bgrad, dgelu @jit_fuser -def dgelu_fused_( - grad_output: torch.Tensor, inp: torch.Tensor -) -> torch.Tensor: +def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor: """ Dgelu fused, this is copy of bgrad_dgelu_fused_ cause jit fusion doesn't allow conditioning. """ x = inp tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ( - (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) - ) + 0.5 * (1 + tanh_out) + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( + 1 + tanh_out + ) dgelu = ff * grad_output return dgelu @@ -187,19 +187,13 @@ def warmup_jit_bias_dropout_add( # Save cuda RNG state to ensure warmup does not affect reproducibility. 
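# Quick numerical check of the fused dGELU formula above (illustrative, not part
# of the diff): the closed-form expression should match autograd through
# PyTorch's tanh-approximated GELU, up to the truncated sqrt(2/pi) constants.
import torch
import torch.nn.functional as F

x = torch.randn(1024, dtype=torch.float64)
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
    1 + tanh_out
)
xg = x.clone().requires_grad_(True)
(auto_grad,) = torch.autograd.grad(F.gelu(xg, approximate="tanh").sum(), xg)
torch.testing.assert_close(ff, auto_grad, rtol=1e-6, atol=1e-6)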
rng_state = torch.cuda.get_rng_state() - inp = torch.rand( - (seq_length, micro_batch_size, hidden_size), dtype=dtype, device="cuda" - ) - residual = torch.rand( - (seq_length, micro_batch_size, hidden_size), dtype=dtype, device="cuda" - ) + inp = torch.rand((seq_length, micro_batch_size, hidden_size), dtype=dtype, device="cuda") + residual = torch.rand((seq_length, micro_batch_size, hidden_size), dtype=dtype, device="cuda") bias = torch.rand((hidden_size), dtype=dtype, device="cuda") dropout_rate = 0.1 # Warmup JIT fusions with the input grad_enable state of both forward # prop and recomputation - for input_grad, bias_grad, residual_grad in zip( - [False, True], [True, True], [True, True] - ): + for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]): inp.requires_grad = input_grad bias.requires_grad = bias_grad residual.requires_grad = residual_grad diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index ab6455649c1ee6fb48b65da77577897feba02edc..281e3fe104bead696c2f4a6ccc1788b21d2a0e04 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -14,23 +14,23 @@ from ..export import is_in_onnx_export_mode from ..fp8 import get_fp8_te_dtype from ..utils import get_default_init_method -def _get_normalization_func(normalization: str, - fp8_output: bool, - is_grad_enabled: bool, - forward: bool): + +def _get_normalization_func( + normalization: str, fp8_output: bool, is_grad_enabled: bool, forward: bool +): fwd_normalization_funcs = { - ('LayerNorm', True, True): tex.layernorm_fwd_fp8, - ('LayerNorm', True, False): tex.layernorm_fwd_fp8_inf, - ('LayerNorm', False, True): tex.layernorm_fwd_noalloc, - ('LayerNorm', False, False): tex.layernorm_fwd_inf, - ('RMSNorm', True, True): tex.rmsnorm_fwd_fp8, - ('RMSNorm', True, False): tex.rmsnorm_fwd_fp8_inf, - ('RMSNorm', False, True): tex.rmsnorm_fwd_noalloc, - ('RMSNorm', False, False): tex.rmsnorm_fwd_inf, + ("LayerNorm", True, True): tex.layernorm_fwd_fp8, + ("LayerNorm", True, False): tex.layernorm_fwd_fp8_inf, + ("LayerNorm", False, True): tex.layernorm_fwd_noalloc, + ("LayerNorm", False, False): tex.layernorm_fwd_inf, + ("RMSNorm", True, True): tex.rmsnorm_fwd_fp8, + ("RMSNorm", True, False): tex.rmsnorm_fwd_fp8_inf, + ("RMSNorm", False, True): tex.rmsnorm_fwd_noalloc, + ("RMSNorm", False, False): tex.rmsnorm_fwd_inf, } bwd_normalization_funcs = { - 'LayerNorm': tex.layernorm_bwd, - 'RMSNorm': tex.rmsnorm_bwd, + "LayerNorm": tex.layernorm_bwd, + "RMSNorm": tex.rmsnorm_bwd, } if forward: @@ -39,21 +39,21 @@ def _get_normalization_func(normalization: str, assert is_grad_enabled, "Gradient has to be enabled to call backward normalization!" 
return bwd_normalization_funcs[normalization] -def _apply_normalization(inputmat:torch.Tensor, - ln_out: torch.Tensor, - ln_weight: torch.Tensor, - ln_bias: Union[torch.Tensor, None], - eps: float, - fp8_out: bool, - fp8_meta: Dict[str, Any], - normalization: str, - fwd_ln_sm_margin: int, - zero_centered_gamma: bool, - is_grad_enabled: bool): - normalization_func = _get_normalization_func(normalization, - fp8_out, - is_grad_enabled, - True) + +def _apply_normalization( + inputmat: torch.Tensor, + ln_out: torch.Tensor, + ln_weight: torch.Tensor, + ln_bias: Union[torch.Tensor, None], + eps: float, + fp8_out: bool, + fp8_meta: Dict[str, Any], + normalization: str, + fwd_ln_sm_margin: int, + zero_centered_gamma: bool, + is_grad_enabled: bool, +): + normalization_func = _get_normalization_func(normalization, fp8_out, is_grad_enabled, True) inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias) if fp8_out: @@ -73,25 +73,28 @@ def _apply_normalization(inputmat:torch.Tensor, **output_kwarg, ) else: - return normalization_func( - *inputs, - eps, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - fwd_ln_sm_margin, - zero_centered_gamma, - ), None, None + return ( + normalization_func( + *inputs, + eps, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + fwd_ln_sm_margin, + zero_centered_gamma, + ), + None, + None, + ) else: if is_grad_enabled: - output = normalization_func( - *inputs, ln_out, eps, - fwd_ln_sm_margin, zero_centered_gamma - ) + output = normalization_func(*inputs, ln_out, eps, fwd_ln_sm_margin, zero_centered_gamma) else: - return normalization_func( - *inputs, eps, fwd_ln_sm_margin, zero_centered_gamma - ), None, None + return ( + normalization_func(*inputs, eps, fwd_ln_sm_margin, zero_centered_gamma), + None, + None, + ) if normalization == "RMSNorm": output = (ln_out, None, output[1]) elif normalization == "LayerNorm": @@ -132,7 +135,7 @@ class _NoopCatFunc(torch.autograd.Function): if ( len(in_shape) != num_dims or in_shape[:dim] != out_shape[:dim] - or in_shape[dim+1:] != out_shape[dim+1:] + or in_shape[dim + 1 :] != out_shape[dim + 1 :] ): raise ValueError( "Attempted to concatenate tensors with shapes " @@ -213,6 +216,7 @@ class _ParameterInitMeta: """ Stores essential metadata needed to support deferred parameter initialization. 
""" + init_fn: Optional[Callable] = get_default_init_method() get_rng_state_tracker: Optional[Callable] = None fp8_meta_index: Optional[int] = None diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index c597ac67a848d997ef9566c87e07eb1ce0d7f07b..be3d4ce6d0c14003f902a388b92b34a48f11a2c4 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -69,7 +69,7 @@ def initialize_ub( tp_group: dist_group_type, use_fp8: bool = False, dtype: torch.dtype = torch.bfloat16, - ub_cfgs: Optional[dict] = None + ub_cfgs: Optional[dict] = None, ) -> None: """Initialize communicators for TP comm overlap using userbuffers.""" global _ub_communicators @@ -86,20 +86,25 @@ def initialize_ub( # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe layers_all_gather_overlap = [ - "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad" + "qkv_fprop", + "qkv_dgrad", + "proj_dgrad", + "fc1_fprop", + "fc1_dgrad", + "fc2_dgrad", ] layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] # Default overlap methods for layers methods = { - "ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], - "pipeline":["proj_fprop", "fc2_fprop"], - "bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], + "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], + "pipeline": ["proj_fprop", "fc2_fprop"], + "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], } # AG-RS overlap pairs of layers forming a tensor-parallel block - ag_rs_pairs = {"qkv_fprop":"proj_fprop", "fc1_fprop":"fc2_fprop"} - rs_ag_pairs = {v : k for k, v in ag_rs_pairs.items()} + ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} + rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} global layers_atomic_ring_exchange layers_atomic_ring_exchange = [] @@ -126,13 +131,13 @@ def initialize_ub( "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." ) assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." - if method == 'bulk': + if method == "bulk": warnings.warn( f"At {name}, atoimic GEMM not is supported for a bulk overlap." "Defaulting to `atomic_gemm=False`." ) atomic_gemm = 0 - if not is_reduce_scatter and method == 'pipeline': + if not is_reduce_scatter and method == "pipeline": raise ValueError( f"At {name}, `pipeline` overlap method is not supported for AllGather." 
) @@ -156,46 +161,45 @@ def initialize_ub( assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message sample_buffer = torch.empty( - shape, - dtype=torch.uint8 if (use_fp8 and fp8_buf) else dtype, - device='cuda') - if method == 'ring_exchange': + shape, dtype=torch.uint8 if (use_fp8 and fp8_buf) else dtype, device="cuda" + ) + if method == "ring_exchange": ub_obj = tex.UbufP2PCommOverlap( - sample_buffer, # Sample userbuffer - rank_id, # Rank id - world_size, # World size - tp_id, # TP id - tp_size, # TP size - num_sm, # Number of communication SMs - cga_size, # CGA cluster size - set_sm_margin, # Set SM margin - aggregate, # Aggregate 2X GEMM chunks - _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams - is_reduce_scatter, # overlap with reduce scatter - atomic_gemm, # use a single GEMM with atomic-counters - torch.Tensor(), # empty tensor to pass to counters - ) + sample_buffer, # Sample userbuffer + rank_id, # Rank id + world_size, # World size + tp_id, # TP id + tp_size, # TP size + num_sm, # Number of communication SMs + cga_size, # CGA cluster size + set_sm_margin, # Set SM margin + aggregate, # Aggregate 2X GEMM chunks + _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams + is_reduce_scatter, # overlap with reduce scatter + atomic_gemm, # use a single GEMM with atomic-counters + torch.Tensor(), # empty tensor to pass to counters + ) else: ub_obj = tex.UbufCommOverlap( - sample_buffer, # Sample userbuffer - rank_id, # Rank id - world_size, # World size - tp_id, # TP id - tp_size, # TP size - num_sm, # Number of communication SMs - cga_size, # CGA cluster size - num_splits, # Number of communication splits - set_sm_margin, # Set SM margin - _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams - atomic_gemm, # use a single GEMM with atomic-counters - torch.Tensor(), # empty tensor to pass to counters - ) + sample_buffer, # Sample userbuffer + rank_id, # Rank id + world_size, # World size + tp_id, # TP id + tp_size, # TP size + num_sm, # Number of communication SMs + cga_size, # CGA cluster size + num_splits, # Number of communication splits + set_sm_margin, # Set SM margin + _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams + atomic_gemm, # use a single GEMM with atomic-counters + torch.Tensor(), # empty tensor to pass to counters + ) _ub_communicators[name] = ub_obj def alloc_copy_allgather_callback(local_data: torch.Tensor, group: str) -> torch.Tensor: pg = None if group == "world" else tp_group global_size = local_data.numel() * torch.distributed.get_world_size(pg) - global_data = torch.zeros(global_size, dtype=local_data.dtype, device='cuda') + global_data = torch.zeros(global_size, dtype=local_data.dtype, device="cuda") torch.distributed.all_gather_into_tensor(global_data, local_data.cuda(), group=pg) return global_data.cpu() @@ -206,21 +210,17 @@ def initialize_ub( def free_callback(data: torch.Tensor) -> None: data.data = torch.Tensor() - tex.set_ubuf_bootstrap_callbacks( - alloc_copy_allgather_callback, - barrier_callback, - free_callback - ) + tex.set_ubuf_bootstrap_callbacks(alloc_copy_allgather_callback, barrier_callback, free_callback) if ub_cfgs is not None: for name in dgrad_reduce_scatter_overlap: - if name in ub_cfgs and 'method' in ub_cfgs[name] and ub_cfgs[name]['method'] != 'bulk': - wgrad_name = name.replace('dgrad','wgrad') + if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk": + wgrad_name = name.replace("dgrad", "wgrad") assert wgrad_name not in ub_cfgs layers_reduce_scatter_overlap.remove(wgrad_name) 
layers_reduce_scatter_overlap.append(name) - for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]): + for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: if ub_cfgs is not None and name in ub_cfgs: ub_cfg = ub_cfgs[name] method = ub_cfg.get("method", get_method(name)) @@ -232,8 +232,9 @@ def initialize_ub( atomic_gemm = ub_cfg.get("atomic_gemm", 0) is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0 # Support FP8 userbuffer when (1) AllGather and (2) FP8-GEMM output ReduceScatter - fp8_buf = ((name in layers_all_gather_overlap) or - (ub_cfg.get("fp8_buf", False) and name in methods["pipeline"])) + fp8_buf = (name in layers_all_gather_overlap) or ( + ub_cfg.get("fp8_buf", False) and name in methods["pipeline"] + ) add_ub( name, method, @@ -264,6 +265,7 @@ def get_ub(name: str): assert name in _ub_communicators, f"UB for {name} is not registered." return _ub_communicators[name] + def destroy_ub(): """Destroy all allocated userbuffer communicators.""" global _ub_communicators @@ -315,7 +317,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): continue if length < curr_len: self.fp8_meta[meta_key].amax_history = ( - self.fp8_meta[meta_key].amax_history[: length].clone()) + self.fp8_meta[meta_key].amax_history[:length].clone() + ) elif length > curr_len: extra_rows = length - curr_len self.fp8_meta[meta_key].amax_history = F.pad( @@ -324,17 +327,20 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): # Update the global buffers with new amax and history pointers. if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: - fwd_pos, fwd_key, bwd_pos, bwd_key = ( - self.fp8_meta[FP8GlobalStateManager.get_buffer_info()]) + fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[ + FP8GlobalStateManager.get_buffer_info() + ] for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): if buffer_key in FP8GlobalStateManager.global_amax_buffer: assert ( buffer_key in FP8GlobalStateManager.global_amax_history_buffer ), "TE internal error during amax history change." - FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = ( - self.fp8_meta[meta_key].amax_history[0]) + FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[ + meta_key + ].amax_history[0] FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = ( - self.fp8_meta[meta_key].amax_history) + self.fp8_meta[meta_key].amax_history + ) def set_meta_tensor(self, fwd: bool) -> None: """Init scales and amaxes for fwd | bwd.""" @@ -347,9 +353,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd - num_fp8_tensors = ( - self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 - ) + num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta() self.fp8_meta[fp8_meta_tensor_key].scale = torch.ones( @@ -387,19 +391,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None: """Reset scales and amaxes.""" + def reset(key): if key in self.fp8_meta: if fp8_meta_tensors is None: self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) self.fp8_meta[key].scale_inv.copy_( - torch.ones_like(self.fp8_meta[key].scale_inv)) + torch.ones_like(self.fp8_meta[key].scale_inv) + ) self.fp8_meta[key].amax_history.copy_( - torch.zeros_like(self.fp8_meta[key].amax_history)) + torch.zeros_like(self.fp8_meta[key].amax_history) + ) else: assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) self.fp8_meta[key].scale_inv.copy_(fp8_meta_tensors[key][1]) self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][2]) + with torch.no_grad(): reset("scaling_fwd") reset("scaling_bwd") @@ -443,7 +451,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): state = pickle.loads(state.detach().cpu().numpy().tobytes()) elif isinstance(state, io.BytesIO): state.seek(0) - state = torch.load(state, map_location='cuda') + state = torch.load(state, map_location="cuda") else: raise RuntimeError("Unsupported checkpoint format.") @@ -529,8 +537,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): if self.fp8 or self.fp8_calibration: # FP8 init has already been run and recipe is the same, don't do anything. - if (self.fp8_initialized - and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]): + if ( + self.fp8_initialized + and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"] + ): return # Set FP8, recipe, and other FP8 metadata @@ -577,20 +587,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): self.init_fp8_metadata(num_gemms=num_gemms) if self.fp8 and self.sequence_parallel: - assert self.fp8_meta["recipe"].reduce_amax, \ - "Amax reduction across tensor parallel group is " \ - "necessary when using sequence parallelism with FP8." + assert self.fp8_meta["recipe"].reduce_amax, ( + "Amax reduction across tensor parallel group is " + "necessary when using sequence parallelism with FP8." + ) if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.fp8_meta, fp8_weights=self._get_fp8_params()) + self.fp8_meta, fp8_weights=self._get_fp8_params() + ) # Activation recomputation is used and this is the first forward phase. 
- if ( - self.fp8 - and self.training - and is_fp8_activation_recompute_enabled() - ): + if self.fp8 and self.training and is_fp8_activation_recompute_enabled(): FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"): @@ -643,23 +651,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): if not ctx.fp8: if gather_grad_output: if not ctx.ub_overlap_ag: - grad_output_mat, _ = gather_along_first_dim( - grad_output_mat, ctx.tp_group - ) + grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group) else: ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True) grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1) return grad_output_mat, None, None, None - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) # FP8 case with non-FP8 wgrad - if ( - gather_grad_output - and ctx.fp8_meta["recipe"].override_linear_precision.wgrad - ): + if gather_grad_output and ctx.fp8_meta["recipe"].override_linear_precision.wgrad: assert ( not ctx.ub_overlap_ag ), "override_linear_precision.wgrad not supported with UB AG overlap" @@ -753,8 +754,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): for name, param in self.named_parameters(recurse=False): # Ensure parameter is on a real device - if param.device == torch.device('meta'): - param = torch.empty_like(param, device='cuda') + if param.device == torch.device("meta"): + param = torch.empty_like(param, device="cuda") # Initialize the parameter values on device init_fn = self.param_init_meta[name].init_fn @@ -837,17 +838,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): # Gather cached Fp8 workspace if it's distributed # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work # for models initialized with Fp8 primary weights. 
- if (not isinstance(out, Float8Tensor) and - fsdp_group is not None and - out._data.shape != tensor.data.shape): + if ( + not isinstance(out, Float8Tensor) + and fsdp_group is not None + and out._data.shape != tensor.data.shape + ): _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) if out is None: - if ( - tensor is None - or fp8_meta_forward is None - or fp8_meta_index is None - ): + if tensor is None or fp8_meta_forward is None or fp8_meta_index is None: raise ValueError( "tensor, fp8_meta_forward, and fp8_meta_index kwargs " "must be provided to construct FP8 workspace" @@ -856,11 +855,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): self.fp8_meta["recipe"], fprop_tensor=fp8_meta_forward, ) - scale_inv = torch.empty( - [1], - dtype=torch.float32, - device=tensor.device - ) + scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device) out = Float8Tensor( data=torch.empty_like(tensor, dtype=torch.uint8), fp8_meta=self.fp8_meta, @@ -880,9 +875,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): update_workspace = True if update_workspace: if tensor is None: - raise ValueError( - "tensor kwarg must be provided to update FP8 workspace" - ) + raise ValueError("tensor kwarg must be provided to update FP8 workspace") if with_transpose: out.cast_transpose_( tensor, @@ -913,8 +906,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): return out - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): """ This function loads tensors and extra state including fp8 metadata. This metadata is essential for copying fp8 tensors, as the copy_ function @@ -929,5 +923,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX if extra_state_key in state_dict: self.set_extra_state(state_dict[extra_state_key]) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 2b2f450bf24db9a25519b1b4e59f86c1959fa9cc..ec33ad2033b3deee7111765bc8174df543a9e05b 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -15,7 +15,7 @@ import transformer_engine_torch as tex from .base import TransformerEngineBaseModule from ..cpp_extensions import ( layernorm_fwd_inf, - ) +) from ..jit import no_torch_dynamo from ..utils import cast_if_needed @@ -51,27 +51,30 @@ class _LayerNorm(torch.autograd.Function): ln_bias = cast_if_needed(ln_bias, activation_dtype) if is_grad_enabled: - ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, - ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma) + ln_out, mu, rsigma = tex.layernorm_fwd( + inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma + ) ctx.save_for_backward(inputmat, ln_weight, mu, rsigma) ctx.inp_shape = inp.shape ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma else: - ln_out, mu, rsigma = layernorm_fwd_inf(inputmat, ln_weight, - ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma), None, None + ln_out, mu, rsigma = ( + layernorm_fwd_inf( 
+ inputmat, ln_weight, ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma + ), + None, + None, + ) return ln_out.view_as(inp) @staticmethod - def backward( - ctx, grad_output: torch.Tensor - ) -> Tuple[Union[torch.Tensor, None], ...]: + def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: inputmat, ln_weight, mu, rsigma = ctx.saved_tensors grad_output = grad_output.contiguous() d_ln_out = grad_output.view(inputmat.shape) dxmat, dgamma, dbeta = tex.layernorm_bwd( - d_ln_out, inputmat, mu, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ) return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None @@ -141,7 +144,7 @@ class LayerNorm(torch.nn.Module): ) self.sequence_parallel = sequence_parallel - self.reset_parameters(defer_init=(device == 'meta')) + self.reset_parameters(defer_init=(device == "meta")) # These many SMs are subtracted from the total SM count when calling forward # and backward LayerNorm C APIs. These envvars can be used to prevent the LN @@ -154,10 +157,10 @@ class LayerNorm(torch.nn.Module): def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( - ("This method will be deprecated in an upcoming release. " - "Update your code to use LayerNorm.reset_parameters() instead."), + "This method will be deprecated in an upcoming release. " + "Update your code to use LayerNorm.reset_parameters() instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) if not self.zero_centered_gamma: init.ones_(self.weight) @@ -170,13 +173,13 @@ class LayerNorm(torch.nn.Module): if defer_init: return - if self.weight.device == torch.device('meta'): - self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device='cuda')) + if self.weight.device == torch.device("meta"): + self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda")) setattr(self.weight, "sequence_parallel", self.sequence_parallel) init.constant_(self.weight, float(not self.zero_centered_gamma)) - if self.bias.device == torch.device('meta'): - self.bias = torch.nn.Parameter(torch.empty_like(self.bias, device='cuda')) + if self.bias.device == torch.device("meta"): + self.bias = torch.nn.Parameter(torch.empty_like(self.bias, device="cuda")) setattr(self.bias, "sequence_parallel", self.sequence_parallel) init.zeros_(self.bias) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 2629c812aefd91b45a045fcc3f8b52c586a26e35..b240d960d8f856fef2670c40258ddddf9b8a53e3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -46,6 +46,7 @@ from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ._common import _apply_normalization, _noop_cat from ..float8_tensor import Float8Tensor + _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) __all__ = ["LayerNormLinear"] @@ -114,7 +115,7 @@ class _LayerNormLinear(torch.autograd.Function): if ub_overlap_ag: dim_size = list(inputmat.size()) dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub(ub_name+"_fprop") + ub_obj_lnout = get_ub(ub_name + "_fprop") if return_layernorm_output: # First prepare LN output in higher precision, # which will be later copied to a FP8 UB @@ -127,17 +128,19 @@ class _LayerNormLinear(torch.autograd.Function): fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], 
fprop_tensor=True) - ln_out, mu, rsigma = _apply_normalization(inputmat, - ln_out, - ln_weight, - ln_bias, - eps, - fp8 and not return_layernorm_output, - fp8_meta, - normalization, - fwd_ln_sm_margin, - zero_centered_gamma, - is_grad_enabled) + ln_out, mu, rsigma = _apply_normalization( + inputmat, + ln_out, + ln_weight, + ln_bias, + eps, + fp8 and not return_layernorm_output, + fp8_meta, + normalization, + fwd_ln_sm_margin, + zero_centered_gamma, + is_grad_enabled, + ) # Column Parallel Linear ln_out_gathered = False @@ -168,7 +171,8 @@ class _LayerNormLinear(torch.autograd.Function): fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, - out=ln_out_fp8) + out=ln_out_fp8, + ) ln_out = ln_out_fp8 else: ln_out_total = tex.cast_to_fp8( @@ -187,13 +191,9 @@ class _LayerNormLinear(torch.autograd.Function): if fp8: if _NVTE_DEBUG: - print('[LayerNormLinear]: using FP8 forward') + print("[LayerNormLinear]: using FP8 forward") - bias_dtype = ( - torch.bfloat16 - if activation_dtype == torch.float32 - else activation_dtype - ) + bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype bias = cast_if_needed(bias, bias_dtype) if use_bias else bias # Use FP8 weights @@ -207,10 +207,15 @@ class _LayerNormLinear(torch.autograd.Function): tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], fp8_dtype_forward, - torch.uint8) + torch.uint8, + ) else: out_index, meta_tensor, output_te_dtype, output_dtype = ( - None, None, None, activation_dtype) + None, + None, + None, + activation_dtype, + ) out, _ = tex.fp8_gemm( weight_fp8._data, weight_fp8._scale_inv, @@ -233,7 +238,8 @@ class _LayerNormLinear(torch.autograd.Function): D_dtype=output_te_dtype, ) if output_dtype == torch.uint8: - out = Float8Tensor(data=out, + out = Float8Tensor( + data=out, fp8_meta=fp8_meta, fp8_meta_forward=True, fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, @@ -242,7 +248,7 @@ class _LayerNormLinear(torch.autograd.Function): ) else: if _NVTE_DEBUG: - print('[LayerNormLinear]: using non-FP8 forward') + print("[LayerNormLinear]: using non-FP8 forward") # Cast for native AMP weight = cast_if_needed(weight, activation_dtype) @@ -251,12 +257,14 @@ class _LayerNormLinear(torch.autograd.Function): if fp8_calibration: # amax of input amin, amax = ln_out_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \ - torch.max(-amin, amax).float() + fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( + -amin, amax + ).float() # amax of weight amin, amax = weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ - torch.max(-amin, amax).float() + fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( + -amin, amax + ).float() out, _, _ = tex.gemm( weight, @@ -323,8 +331,9 @@ class _LayerNormLinear(torch.autograd.Function): ctx.tp_group = tp_group ctx.tp_size = tp_size ctx.return_layernorm_output = return_layernorm_output - ctx.return_layernorm_output_gathered = return_layernorm_output_gathered \ - and ln_out_gathered + ctx.return_layernorm_output_gathered = ( + return_layernorm_output_gathered and ln_out_gathered + ) ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma ctx.ub_bulk_wgrad = ub_bulk_wgrad @@ -336,8 +345,9 @@ class _LayerNormLinear(torch.autograd.Function): ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): ctx.reduce_and_update_bwd_fp8_tensors = ( 
- ctx.reduce_and_update_bwd_fp8_tensors or - FP8GlobalStateManager.is_first_fp8_module()) + ctx.reduce_and_update_bwd_fp8_tensors + or FP8GlobalStateManager.is_first_fp8_module() + ) # Row Parallel Linear if parallel_mode == "row" and sequence_parallel: @@ -356,14 +366,14 @@ class _LayerNormLinear(torch.autograd.Function): return out, ln_out_return.view_as(inp) return out - @staticmethod def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: if isinstance(grad_outputs[0], Float8Tensor): - ctx.fp8_meta["scaling_bwd"].scale_inv[ - tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[0]._scale_inv + ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[ + 0 + ]._scale_inv with torch.cuda.nvtx.range("_LayerNormLinear_backward"): ( @@ -407,7 +417,7 @@ class _LayerNormLinear(torch.autograd.Function): if ctx.ub_bulk_dgrad: dim_size = list(ln_out.size()) dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub(ctx.ub_name+"_dgrad") + ub_obj_lnout = get_ub(ctx.ub_name + "_dgrad") ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) ( grad_output, @@ -425,13 +435,13 @@ class _LayerNormLinear(torch.autograd.Function): # Column Parallel Linear # Overlap input AG with dgrad - if (weight.requires_grad + if ( + weight.requires_grad and (not ctx.ub_bulk_dgrad) and ctx.parallel_mode == "column" - and ctx.sequence_parallel): - ln_out_total, handle = gather_along_first_dim( - ln_out, ctx.tp_group, async_op=True - ) + and ctx.sequence_parallel + ): + ln_out_total, handle = gather_along_first_dim(ln_out, ctx.tp_group, async_op=True) else: ln_out_total = ln_out handle = None @@ -443,15 +453,14 @@ class _LayerNormLinear(torch.autograd.Function): else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - dgrad_size = list(grad_output.size()) dgrad_size[1] = weight.size(1) - if ctx.ub_bulk_wgrad: # allocate dgrad output - ub_obj_dgrad = get_ub(ctx.ub_name+"_wgrad") - dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + if ctx.ub_bulk_wgrad: # allocate dgrad output + ub_obj_dgrad = get_ub(ctx.ub_name + "_wgrad") + dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output elif ctx.ub_overlap_rs_dgrad: - ub_obj_dgrad = get_ub(ctx.ub_name+"_dgrad") - dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + ub_obj_dgrad = get_ub(ctx.ub_name + "_dgrad") + dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output else: dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device) @@ -463,10 +472,11 @@ class _LayerNormLinear(torch.autograd.Function): dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = weight.size(1) rs_out = torch.empty( - dim_size, dtype=ctx.activation_dtype, device=grad_output.device) + dim_size, dtype=ctx.activation_dtype, device=grad_output.device + ) if ub_obj_dgrad.is_p2p_overlap(): if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P else: ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: @@ -481,16 +491,16 @@ class _LayerNormLinear(torch.autograd.Function): if ctx.fp8: if _NVTE_DEBUG: - print('[LayerNormLinear]: using FP8 backward') + print("[LayerNormLinear]: using FP8 backward") - fp8_dtype_forward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=True - ) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = 
get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) out_index, meta_tensor, out_te_type, out_type = ( - None, None, None, ctx.activation_dtype) + None, + None, + None, + ctx.activation_dtype, + ) if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf(): out_index = tex.FP8BwdTensors.GRAD_INPUT1 meta_tensor = ctx.fp8_meta["scaling_bwd"] @@ -504,8 +514,11 @@ class _LayerNormLinear(torch.autograd.Function): weight_fp8._scale_inv, 0, weight_fp8._fp8_dtype, - grad_output_c._data - if isinstance(grad_output_c, Float8Tensor) else grad_output_c, + ( + grad_output_c._data + if isinstance(grad_output_c, Float8Tensor) + else grad_output_c + ), ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, @@ -517,13 +530,13 @@ class _LayerNormLinear(torch.autograd.Function): ub=ub_obj, extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, out_index=out_index, - fp8_meta_tensor = meta_tensor, - D_dtype = out_te_type, + fp8_meta_tensor=meta_tensor, + D_dtype=out_te_type, ) clear_tensor_data(grad_output_c) else: if _NVTE_DEBUG: - print('[LayerNormLinear]: using non-FP8 backward') + print("[LayerNormLinear]: using non-FP8 backward") # DGRAD: Evaluated unconditionally to feed into Linear backward _, _, _ = tex.gemm( @@ -560,9 +573,10 @@ class _LayerNormLinear(torch.autograd.Function): extra_output_tensor = None if ctx.ub_bulk_wgrad: if ub_obj_dgrad.is_fp8_ubuf(): - dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output + dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output extra_output_tensor = torch.empty( - dim_size, dtype=ctx.activation_dtype, device=dgrad.device) + dim_size, dtype=ctx.activation_dtype, device=dgrad.device + ) dgrad = extra_output_tensor else: dgrad = ub_obj_dgrad.get_ubuf_output(0) @@ -573,8 +587,11 @@ class _LayerNormLinear(torch.autograd.Function): fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, - grad_output_t._data - if isinstance(grad_output_t, Float8Tensor) else grad_output_t, + ( + grad_output_t._data + if isinstance(grad_output_t, Float8Tensor) + else grad_output_t + ), ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, @@ -583,10 +600,11 @@ class _LayerNormLinear(torch.autograd.Function): accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS - if ctx.ub_bulk_wgrad else None, + ub_algo=( + tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor + extra_output_tensor=extra_output_tensor, ) clear_tensor_data(ln_out_total_t, grad_output_t) else: @@ -606,10 +624,11 @@ class _LayerNormLinear(torch.autograd.Function): grad=True, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS - if ctx.ub_bulk_wgrad else None, + ub_algo=( + tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor + extra_output_tensor=extra_output_tensor, ) clear_tensor_data(ln_out_total_c) else: @@ -625,17 +644,19 @@ class _LayerNormLinear(torch.autograd.Function): accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, 
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ) clear_tensor_data(ln_out_total) if ctx.ub_bulk_wgrad: - dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output + dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output # Column Parallel Linear - if ((not ctx.ub_bulk_wgrad) + if ( + (not ctx.ub_bulk_wgrad) and ctx.parallel_mode == "column" and ctx.tensor_parallel - and handle is not None): + and handle is not None + ): handle.wait() # LayerNorm gradient @@ -650,13 +671,22 @@ class _LayerNormLinear(torch.autograd.Function): if ctx.normalization == "LayerNorm": dgrad, dgamma, dbeta = tex.layernorm_bwd( - dgrad, inputmat, mu, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + dgrad, + inputmat, + mu, + rsigma, + ln_weight, + ctx.bwd_ln_sm_margin, + ctx.zero_centered_gamma, ) elif ctx.normalization == "RMSNorm": dgrad, dgamma = tex.rmsnorm_bwd( - dgrad, inputmat, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + dgrad, + inputmat, + rsigma, + ln_weight, + ctx.bwd_ln_sm_margin, + ctx.zero_centered_gamma, ) dbeta = None clear_tensor_data(mu) @@ -667,20 +697,22 @@ class _LayerNormLinear(torch.autograd.Function): if weight.requires_grad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'): + if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"): weight.grad_added_to_main_grad = True - if getattr(weight, 'zero_out_wgrad', False): - wgrad = torch.zeros(weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False - ) + if getattr(weight, "zero_out_wgrad", False): + wgrad = torch.zeros( + weight.main_grad.shape, + dtype=weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) else: - wgrad = torch.empty(weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False - ) + wgrad = torch.empty( + weight.main_grad.shape, + dtype=weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) elif ctx.fuse_wgrad_accumulation: wgrad = None else: @@ -828,7 +860,7 @@ class LayerNormLinear(TransformerEngineBaseModule): get_rng_state_tracker: Optional[Callable] = None, init_method: Optional[Callable] = None, bias: bool = True, - normalization: str = 'LayerNorm', + normalization: str = "LayerNorm", return_bias: bool = False, params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, @@ -850,7 +882,7 @@ class LayerNormLinear(TransformerEngineBaseModule): self.out_features = out_features self.fuse_wgrad_accumulation = fuse_wgrad_accumulation self.normalization = normalization - assert normalization in ['LayerNorm', 'RMSNorm'], "Unsupported normalization type!" + assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" 
self.use_bias = bias self.return_bias = return_bias self.apply_bias = self.use_bias and not return_bias @@ -893,14 +925,18 @@ class LayerNormLinear(TransformerEngineBaseModule): layer_norm_weight = torch.nn.Parameter( torch.empty(in_features, device=device, dtype=params_dtype) ) - self.register_parameter('layer_norm_weight', layer_norm_weight, - init_fn=init_method_constant(float(not self.zero_centered_gamma))) + self.register_parameter( + "layer_norm_weight", + layer_norm_weight, + init_fn=init_method_constant(float(not self.zero_centered_gamma)), + ) if self.normalization != "RMSNorm": layer_norm_bias = torch.nn.Parameter( torch.empty(in_features, device=device, dtype=params_dtype) ) - self.register_parameter('layer_norm_bias', layer_norm_bias, - init_fn=init_method_constant(0.0)) + self.register_parameter( + "layer_norm_bias", layer_norm_bias, init_fn=init_method_constant(0.0) + ) else: self.layer_norm_bias = None @@ -980,10 +1016,7 @@ class LayerNormLinear(TransformerEngineBaseModule): # Check if parameters are subviews of buffers is_subview = (split_start, split_end) != (0, self.out_features) if is_subview and with_fp8_params: - raise RuntimeError( - "Splitting Float8Tensor into multiple params " - "is not supported" - ) + raise RuntimeError("Splitting Float8Tensor into multiple params is not supported") # Construct weight parameter self.register_parameter( @@ -1014,7 +1047,7 @@ class LayerNormLinear(TransformerEngineBaseModule): if with_fp8_params: self.init_fp8_metadata() - self.reset_parameters(defer_init=(device == 'meta')) + self.reset_parameters(defer_init=(device == "meta")) # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM @@ -1034,10 +1067,10 @@ class LayerNormLinear(TransformerEngineBaseModule): def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( - ("This method will be deprecated in an upcoming release. " - "Update your code to use LayerNormLinear.reset_parameters() instead."), + "This method will be deprecated in an upcoming release. 
" + "Update your code to use LayerNormLinear.reset_parameters() instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) if not self.zero_centered_gamma: init.ones_(self.layer_norm_weight) @@ -1112,8 +1145,7 @@ class LayerNormLinear(TransformerEngineBaseModule): if self.fp8: if len(unfused_weights) != 1: raise RuntimeError( - "Splitting Float8Tensor into multiple params " - "is not supported" + "Splitting Float8Tensor into multiple params is not supported" ) else: unfused_weights = [w.from_float8() for w in unfused_weights] @@ -1140,8 +1172,7 @@ class LayerNormLinear(TransformerEngineBaseModule): update_transpose_cache = with_transpose if update_transpose_cache: update_transpose_cache = ( - is_first_microbatch - or skip_fp8_weight_update is not None + is_first_microbatch or skip_fp8_weight_update is not None ) if update_transpose_cache: weight_tensor.transpose_2d( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index fe67b5877e71bac53bf73e24cf5524e0fb9b153f..8b971e186bdadb3585e6d21c5da6619da32ca651 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -61,13 +61,13 @@ __all__ = ["LayerNormMLP"] def _act_func(activation: str): funcs = { - 'gelu': (tex.gelu, tex.dgelu), - 'relu': (tex.relu, tex.drelu), - 'geglu': (tex.geglu, tex.dgeglu), - 'reglu': (tex.reglu, tex.dreglu), - 'swiglu': (tex.swiglu, tex.dswiglu), - 'qgelu': (tex.qgelu, tex.dqgelu), - 'srelu': (tex.srelu, tex.dsrelu), + "gelu": (tex.gelu, tex.dgelu), + "relu": (tex.relu, tex.drelu), + "geglu": (tex.geglu, tex.dgeglu), + "reglu": (tex.reglu, tex.dreglu), + "swiglu": (tex.swiglu, tex.dswiglu), + "qgelu": (tex.qgelu, tex.dqgelu), + "srelu": (tex.srelu, tex.dsrelu), } if activation not in funcs: raise NotImplementedError("Activation type " + activation + " is not supported!") @@ -154,17 +154,19 @@ class _LayerNormMLP(torch.autograd.Function): fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - ln_out, mu, rsigma = _apply_normalization(inputmat, - ln_out, - ln_weight, - ln_bias, - eps, - fp8 and not return_layernorm_output, - fp8_meta, - normalization, - fwd_ln_sm_margin, - zero_centered_gamma, - is_grad_enabled) + ln_out, mu, rsigma = _apply_normalization( + inputmat, + ln_out, + ln_weight, + ln_bias, + eps, + fp8 and not return_layernorm_output, + fp8_meta, + normalization, + fwd_ln_sm_margin, + zero_centered_gamma, + is_grad_enabled, + ) # Column Parallel Linear ln_out_gathered = False @@ -210,11 +212,7 @@ class _LayerNormMLP(torch.autograd.Function): ln_out = ln_out_total if fp8: - bias_dtype = ( - torch.bfloat16 - if activation_dtype == torch.float32 - else activation_dtype - ) + bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias @@ -277,7 +275,11 @@ class _LayerNormMLP(torch.autograd.Function): clear_tensor_data(fc1_out) fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = ( - None, None, None, activation_dtype) + None, + None, + None, + activation_dtype, + ) if ub_overlap_rs: ub_obj_fc2out = get_ub("fc2_fprop") fc2_out = ub_obj_fc2out.get_ubuf_output(1) @@ -326,8 +328,8 @@ class _LayerNormMLP(torch.autograd.Function): ub=ub_obj_fc2out if ub_overlap_rs else None, extra_output_tensor=rs_out if ub_overlap_rs else None, out_index=fc2_out_index, - fp8_meta_tensor = 
fc2_meta_tensor, - D_dtype = fc2_te_type, + fp8_meta_tensor=fc2_meta_tensor, + D_dtype=fc2_te_type, ) if not is_grad_enabled: clear_tensor_data(gelu_out) @@ -335,22 +337,20 @@ class _LayerNormMLP(torch.autograd.Function): # Cast for native AMP fc1_weight = cast_if_needed(fc1_weight, activation_dtype) fc2_weight = cast_if_needed(fc2_weight, activation_dtype) - fc1_bias = ( - cast_if_needed(fc1_bias, activation_dtype) if use_fc1_bias else fc1_bias - ) - fc2_bias = ( - cast_if_needed(fc2_bias, activation_dtype) if use_fc2_bias else fc2_bias - ) + fc1_bias = cast_if_needed(fc1_bias, activation_dtype) if use_fc1_bias else fc1_bias + fc2_bias = cast_if_needed(fc2_bias, activation_dtype) if use_fc2_bias else fc2_bias if fp8_calibration: # amax of fc1 input amin, amax = ln_out_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \ - torch.max(-amin, amax).float() + fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( + -amin, amax + ).float() # amax of fc1 weight amin, amax = fc1_weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ - torch.max(-amin, amax).float() + fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( + -amin, amax + ).float() fc1_outputs = tex.gemm( fc1_weight, @@ -359,7 +359,7 @@ class _LayerNormMLP(torch.autograd.Function): get_workspace(), bias=fc1_bias, use_bias=(not bias_gelu_nvfusion) and use_fc1_bias, - gelu=not bias_gelu_nvfusion and (activation == 'gelu'), + gelu=not bias_gelu_nvfusion and (activation == "gelu"), ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None, @@ -371,26 +371,27 @@ class _LayerNormMLP(torch.autograd.Function): fc1_out, _, _ = fc1_outputs gelu_out = bias_gelu_fused(fc1_out, fc1_bias) else: - if activation == 'gelu': + if activation == "gelu": gelu_out, _, fc1_out = fc1_outputs else: fc1_out, _, _ = fc1_outputs - gelu_out = activation_func(fc1_out, - None, - tex.FP8FwdTensors.GEMM2_INPUT, - TE_DType[fc1_out.dtype]) + gelu_out = activation_func( + fc1_out, None, tex.FP8FwdTensors.GEMM2_INPUT, TE_DType[fc1_out.dtype] + ) if not is_grad_enabled: clear_tensor_data(fc1_out) if fp8_calibration: # amax of fc2 input amin, amax = gelu_out.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = \ - torch.max(-amin, amax).float() + fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = torch.max( + -amin, amax + ).float() # amax of fc2 weight amin, amax = fc2_weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \ - torch.max(-amin, amax).float() + fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = torch.max( + -amin, amax + ).float() if ub_overlap_rs: ub_obj_fc2out = get_ub("fc2_fprop") @@ -493,8 +494,9 @@ class _LayerNormMLP(torch.autograd.Function): ctx.tp_size = tp_size ctx.bias_gelu_nvfusion = bias_gelu_nvfusion ctx.return_layernorm_output = return_layernorm_output - ctx.return_layernorm_output_gathered = return_layernorm_output_gathered \ - and ln_out_gathered + ctx.return_layernorm_output_gathered = ( + return_layernorm_output_gathered and ln_out_gathered + ) ctx.set_parallel_mode = set_parallel_mode ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma @@ -506,7 +508,8 @@ class _LayerNormMLP(torch.autograd.Function): ctx.normalization = normalization 
ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad( - inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias): + inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + ): ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() # Row Parallel Linear @@ -528,7 +531,6 @@ class _LayerNormMLP(torch.autograd.Function): return fc2_out, ln_out_return.view_as(inp) return fc2_out - @staticmethod def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] @@ -605,15 +607,13 @@ class _LayerNormMLP(torch.autograd.Function): else: ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P - ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess + ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess ( grad_output, grad_output_c, grad_output_t, fc2_bias_grad, - ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_outputs[0], True - ) + ) = TransformerEngineBaseModule.grad_output_preprocess(ctx, grad_outputs[0], True) if ctx.ub_bulk_wgrad: tp_world_size = get_distributed_world_size(ctx.tp_group) @@ -621,13 +621,13 @@ class _LayerNormMLP(torch.autograd.Function): ctx.ub_bulk_wgrad = False # Column Parallel Linear # Overlap input AG with dgrad - if (fc1_weight.requires_grad + if ( + fc1_weight.requires_grad and (not ctx.ub_bulk_dgrad) and ctx.set_parallel_mode - and ctx.sequence_parallel): - ln_out_total, handle = gather_along_first_dim( - ln_out, ctx.tp_group, async_op=True - ) + and ctx.sequence_parallel + ): + ln_out_total, handle = gather_along_first_dim(ln_out, ctx.tp_group, async_op=True) else: ln_out_total = ln_out handle = None @@ -640,12 +640,8 @@ class _LayerNormMLP(torch.autograd.Function): accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=True - ) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) # FC2 DGRAD; Unconditional fc2_dgrad, _ = tex.fp8_gemm( @@ -684,14 +680,12 @@ class _LayerNormMLP(torch.autograd.Function): ctx.activation_dtype, get_workspace(), accumulate=accumulate_wgrad_into_param_main_grad, - out=fc2_weight.main_grad - if ctx.fuse_wgrad_accumulation - else None, + out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, ) clear_tensor_data(gelu_out_t, grad_output_t) - if ctx.activation == 'gelu': + if ctx.activation == "gelu": fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_dgelu_fused( fc2_dgrad, fc1_out, @@ -700,8 +694,7 @@ class _LayerNormMLP(torch.autograd.Function): fp8_dtype_backward, ) else: - dgelu = activation_func(fc2_dgrad, fc1_out, - TE_DType[fc2_dgrad.dtype]) + dgelu = activation_func(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_fused( dgelu, ctx.fp8_meta["scaling_bwd"], @@ -728,20 +721,18 @@ class _LayerNormMLP(torch.autograd.Function): grad=True, use_bias=False, accumulate=accumulate_wgrad_into_param_main_grad, - out=fc2_weight.main_grad - if ctx.fuse_wgrad_accumulation - else None, + out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) clear_tensor_data(gelu_out_c) - if ctx.activation == 'gelu': + if ctx.activation == "gelu": fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused( fc2_dgrad, fc1_out, fc1_bias ) else: - 
dgelu_no_fp8 = activation_func(fc2_dgrad, - fc1_out, - TE_DType[fc2_dgrad.dtype]) + dgelu_no_fp8 = activation_func( + fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype] + ) fc1_bias_grad = dgelu_no_fp8.sum(dim=0) clear_tensor_data(fc1_out) @@ -754,16 +745,20 @@ class _LayerNormMLP(torch.autograd.Function): dgelu_t = None out_index, meta_tensor, out_te_type, out_type = ( - None, None, None, ctx.activation_dtype) + None, + None, + None, + ctx.activation_dtype, + ) fc1_dgrad_size = list(dgelu.size()) fc1_dgrad_size[1] = fc1_weight.size(1) # Get/alloc fc1_dgrad - if ctx.ub_bulk_wgrad: # allocate dgrad output + if ctx.ub_bulk_wgrad: # allocate dgrad output ub_obj_dgrad = get_ub("fc1_wgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output elif ctx.ub_overlap_rs_dgrad: ub_obj_dgrad = get_ub("fc1_dgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output else: fc1_dgrad = torch.empty( fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device @@ -788,7 +783,7 @@ class _LayerNormMLP(torch.autograd.Function): rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) if ub_obj_dgrad.is_p2p_overlap(): if ub_obj_dgrad.is_atomic_gemm(): - ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P else: ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: @@ -818,8 +813,8 @@ class _LayerNormMLP(torch.autograd.Function): ub=ub_obj, extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, out_index=out_index, - fp8_meta_tensor = meta_tensor, - D_dtype = out_te_type, + fp8_meta_tensor=meta_tensor, + D_dtype=out_te_type, ) else: # FC2 DGRAD; Unconditional @@ -829,11 +824,12 @@ class _LayerNormMLP(torch.autograd.Function): ctx.activation_dtype, get_workspace(), layout="NN", - gelu=(not ctx.bias_gelu_nvfusion) and (ctx.activation == 'gelu'), + gelu=(not ctx.bias_gelu_nvfusion) and (ctx.activation == "gelu"), grad=True, gelu_input=fc1_out, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P \ - if ctx.ub_overlap_ag else None, + ub_algo=( + tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None + ), ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, ) @@ -852,13 +848,11 @@ class _LayerNormMLP(torch.autograd.Function): ) clear_tensor_data(gelu_out) - if ctx.bias_gelu_nvfusion and ctx.activation == 'gelu': + if ctx.bias_gelu_nvfusion and ctx.activation == "gelu": fc1_bias_grad, fc2_dgrad = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias) else: - if ctx.activation != 'gelu': - fc2_dgrad = activation_func(fc2_dgrad, - fc1_out, - TE_DType[fc2_dgrad.dtype]) + if ctx.activation != "gelu": + fc2_dgrad = activation_func(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) # For non-fp8 execution, FC1 bias gradient is fused with FC1 wgrad GEMM # and will not be calculated in case wgrad is not required. 
@@ -871,12 +865,12 @@ class _LayerNormMLP(torch.autograd.Function): fc1_dgrad_size = list(dgelu.size()) fc1_dgrad_size[1] = fc1_weight.size(1) - if ctx.ub_bulk_wgrad: # allocate dgrad output + if ctx.ub_bulk_wgrad: # allocate dgrad output ub_obj_dgrad = get_ub("fc1_wgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output elif ctx.ub_overlap_rs_dgrad: ub_obj_dgrad = get_ub("fc1_dgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output else: fc1_dgrad = torch.empty( fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device @@ -934,9 +928,10 @@ class _LayerNormMLP(torch.autograd.Function): extra_output_tensor = None if ctx.ub_bulk_wgrad: if ub_obj_dgrad.is_fp8_ubuf(): - dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output + dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output extra_output_tensor = torch.empty( - dim_size, dtype=ctx.activation_dtype, device=fc1_dgrad.device) + dim_size, dtype=ctx.activation_dtype, device=fc1_dgrad.device + ) fc1_dgrad = extra_output_tensor else: fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) @@ -954,12 +949,11 @@ class _LayerNormMLP(torch.autograd.Function): ctx.activation_dtype, get_workspace(), accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad - if ctx.fuse_wgrad_accumulation - else None, + out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS - if ctx.ub_bulk_wgrad else None, + ub_algo=( + tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, ) @@ -980,11 +974,10 @@ class _LayerNormMLP(torch.autograd.Function): layout="NT", grad=True, accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad - if ctx.fuse_wgrad_accumulation - else None, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS - if ctx.ub_bulk_wgrad else None, + out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub_algo=( + tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, ) @@ -1002,7 +995,7 @@ class _LayerNormMLP(torch.autograd.Function): accumulate=accumulate_wgrad_into_param_main_grad, out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ) clear_tensor_data(ln_out_total, dgelu) @@ -1011,13 +1004,15 @@ class _LayerNormMLP(torch.autograd.Function): else: fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs if ctx.ub_bulk_wgrad: - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output # Column Parallel Linear - if ((not ctx.ub_bulk_wgrad) + if ( + (not ctx.ub_bulk_wgrad) and ctx.set_parallel_mode and ctx.tensor_parallel - and handle is not None): + and handle is not None + ): handle.wait() # LayerNorm gradient @@ -1032,13 +1027,22 @@ class _LayerNormMLP(torch.autograd.Function): if ctx.normalization == "LayerNorm": dgrad, dgamma, dbeta = tex.layernorm_bwd( - dgrad, inputmat, mu, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + dgrad, + inputmat, + mu, + 
rsigma, + ln_weight, + ctx.bwd_ln_sm_margin, + ctx.zero_centered_gamma, ) elif ctx.normalization == "RMSNorm": dgrad, dgamma = tex.rmsnorm_bwd( - dgrad, inputmat, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + dgrad, + inputmat, + rsigma, + ln_weight, + ctx.bwd_ln_sm_margin, + ctx.zero_centered_gamma, ) dbeta = None clear_tensor_data(mu) @@ -1046,20 +1050,22 @@ class _LayerNormMLP(torch.autograd.Function): if fc1_weight.requires_grad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, 'grad_added_to_main_grad'): + if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, "grad_added_to_main_grad"): fc1_weight.grad_added_to_main_grad = True - if getattr(fc1_weight, 'zero_out_wgrad', False): - fc1_wgrad = torch.zeros(fc1_weight.main_grad.shape, - dtype=fc1_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False - ) + if getattr(fc1_weight, "zero_out_wgrad", False): + fc1_wgrad = torch.zeros( + fc1_weight.main_grad.shape, + dtype=fc1_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) else: - fc1_wgrad = torch.empty(fc1_weight.main_grad.shape, - dtype=fc1_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False - ) + fc1_wgrad = torch.empty( + fc1_weight.main_grad.shape, + dtype=fc1_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) elif ctx.fuse_wgrad_accumulation: fc1_wgrad = None else: @@ -1067,20 +1073,22 @@ class _LayerNormMLP(torch.autograd.Function): if fc2_weight.requires_grad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, 'grad_added_to_main_grad'): + if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, "grad_added_to_main_grad"): fc2_weight.grad_added_to_main_grad = True - if getattr(fc2_weight, 'zero_out_wgrad', False): - fc2_wgrad = torch.zeros(fc2_weight.main_grad.shape, - dtype=fc2_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False - ) + if getattr(fc2_weight, "zero_out_wgrad", False): + fc2_wgrad = torch.zeros( + fc2_weight.main_grad.shape, + dtype=fc2_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) else: - fc2_wgrad = torch.empty(fc2_weight.main_grad.shape, - dtype=fc2_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False - ) + fc2_wgrad = torch.empty( + fc2_weight.main_grad.shape, + dtype=fc2_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) elif ctx.fuse_wgrad_accumulation: fc2_wgrad = None else: @@ -1094,7 +1102,7 @@ class _LayerNormMLP(torch.autograd.Function): _fsdp_scatter_tensors( ctx.fsdp_group, fc1_weight_fp8 if not isinstance(fc1_weight, Float8Tensor) else None, - fc2_weight_fp8 if not isinstance(fc2_weight, Float8Tensor) else None + fc2_weight_fp8 if not isinstance(fc2_weight, Float8Tensor) else None, ) return ( @@ -1246,8 +1254,8 @@ class LayerNormMLP(TransformerEngineBaseModule): tp_size: int = 1, init_method: Optional[Callable] = None, bias: bool = True, - normalization: str = 'LayerNorm', - activation : str = "gelu", + normalization: str = "LayerNorm", + activation: str = "gelu", output_layer_init_method: Optional[Callable] = None, fuse_wgrad_accumulation: bool = False, params_dtype: Optional[torch.dtype] = None, @@ -1269,15 +1277,16 @@ class LayerNormMLP(TransformerEngineBaseModule): params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.fuse_wgrad_accumulation = fuse_wgrad_accumulation self.normalization = normalization - assert 
normalization in ['LayerNorm', 'RMSNorm'], "Unsupported normalization type!" + assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" self.use_bias = bias self.activation = activation self.return_bias = return_bias self.apply_bias = bias and not return_bias self.return_layernorm_output = return_layernorm_output self.return_layernorm_output_gathered = return_layernorm_output_gathered - self.bias_gelu_nvfusion = (bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1"))) and - self.activation == 'gelu') + self.bias_gelu_nvfusion = ( + bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1"))) and self.activation == "gelu" + ) self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma self.ub_bulk_wgrad = ub_bulk_wgrad @@ -1286,9 +1295,11 @@ class LayerNormMLP(TransformerEngineBaseModule): self.ub_overlap_rs = ub_overlap_rs self.ub_overlap_ag = ub_overlap_ag # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap - self.gemm_gelu_fusion = \ - (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and - self.activation == 'gelu' and not get_ub("fc1_fprop").is_atomic_gemm()) + self.gemm_gelu_fusion = ( + bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) + and self.activation == "gelu" + and not get_ub("fc1_fprop").is_atomic_gemm() + ) if tp_group is None: self.tp_size = tp_size @@ -1312,42 +1323,42 @@ class LayerNormMLP(TransformerEngineBaseModule): # LN init self.eps = eps - layer_norm_weight = Parameter( - torch.empty(hidden_size, device=device, dtype=params_dtype) + layer_norm_weight = Parameter(torch.empty(hidden_size, device=device, dtype=params_dtype)) + self.register_parameter( + "layer_norm_weight", + layer_norm_weight, + init_fn=init_method_constant(float(not self.zero_centered_gamma)), ) - self.register_parameter('layer_norm_weight', layer_norm_weight, - init_fn=init_method_constant(float(not self.zero_centered_gamma))) if self.normalization != "RMSNorm": - layer_norm_bias = Parameter( - torch.empty(hidden_size, device=device, dtype=params_dtype) + layer_norm_bias = Parameter(torch.empty(hidden_size, device=device, dtype=params_dtype)) + self.register_parameter( + "layer_norm_bias", layer_norm_bias, init_fn=init_method_constant(0.0) ) - self.register_parameter('layer_norm_bias', layer_norm_bias, - init_fn=init_method_constant(0.0)) else: self.layer_norm_bias = None # FC1 init - if self.activation in ['reglu', 'geglu', 'swiglu']: + if self.activation in ["reglu", "geglu", "swiglu"]: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition fc1_weight = Parameter( - torch.empty( - fc1_output_features, hidden_size, device=device, dtype=params_dtype - ) + torch.empty(fc1_output_features, hidden_size, device=device, dtype=params_dtype) + ) + self.register_parameter( + "fc1_weight", + fc1_weight, + init_fn=init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, ) - self.register_parameter('fc1_weight', fc1_weight, - init_fn=init_method, - get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT) if self.use_bias: fc1_bias = Parameter( torch.empty(fc1_output_features, device=device, dtype=params_dtype) ) - self.register_parameter('fc1_bias', fc1_bias, - init_fn=init_method_constant(0.0)) + self.register_parameter("fc1_bias", fc1_bias, init_fn=init_method_constant(0.0)) else: self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device) @@ -1355,24 +1366,24 @@ class 
LayerNormMLP(TransformerEngineBaseModule): fc2_weight = Parameter( torch.empty(hidden_size, self.size_per_partition, device=device, dtype=params_dtype) ) - self.register_parameter('fc2_weight', fc2_weight, - init_fn=output_layer_init_method, - get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT) + self.register_parameter( + "fc2_weight", + fc2_weight, + init_fn=output_layer_init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, + ) if self.use_bias: - fc2_bias = Parameter( - torch.empty(hidden_size, device=device, dtype=params_dtype) - ) - self.register_parameter('fc2_bias', fc2_bias, - init_fn=init_method_constant(0.0)) + fc2_bias = Parameter(torch.empty(hidden_size, device=device, dtype=params_dtype)) + self.register_parameter("fc2_bias", fc2_bias, init_fn=init_method_constant(0.0)) else: self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device) if with_fp8_params: self.init_fp8_metadata(num_gemms=2) - self.reset_parameters(defer_init=(device == 'meta')) + self.reset_parameters(defer_init=(device == "meta")) # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM @@ -1399,10 +1410,10 @@ class LayerNormMLP(TransformerEngineBaseModule): def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( - ("This method will be deprecated in an upcoming release. " - "Update your code to use LayerNormMLP.reset_parameters() instead."), + "This method will be deprecated in an upcoming release. " + "Update your code to use LayerNormMLP.reset_parameters() instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) if not self.zero_centered_gamma: init.ones_(self.layer_norm_weight) @@ -1430,9 +1441,7 @@ class LayerNormMLP(TransformerEngineBaseModule): @no_torch_dynamo() def forward( - self, - inp: torch.Tensor, - is_first_microbatch: Optional[bool] = None + self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a feedforward network (MLP Block). 
@@ -1475,10 +1484,7 @@ class LayerNormMLP(TransformerEngineBaseModule): fc1_weight_fp8 = None fc2_weight_fp8 = None if self.fp8: - update_workspace = ( - is_first_microbatch is None - or is_first_microbatch - ) + update_workspace = is_first_microbatch is None or is_first_microbatch with_transpose = torch.is_grad_enabled() if ( is_fp8_activation_recompute_enabled() @@ -1488,8 +1494,7 @@ class LayerNormMLP(TransformerEngineBaseModule): update_transpose_cache = with_transpose if update_transpose_cache: update_transpose_cache = ( - is_first_microbatch - or skip_fp8_weight_update is not None + is_first_microbatch or skip_fp8_weight_update is not None ) if isinstance(fc1_weight, Float8Tensor): if update_transpose_cache: @@ -1531,8 +1536,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ) # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode - if (self.bias_gelu_nvfusion - and not use_reentrant_activation_recompute()): + if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): self.bias_gelu_nvfusion = False from ..cpu_offload import CPUOffloadEnabled diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 62753538c416569edcbb7cba45ab080ff3ccfa4b..79ec1eed5ad64e020bab31d2f36c446bf187a165 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -148,13 +148,9 @@ class _Linear(torch.autograd.Function): inputmat_total = inputmat if fp8: if _NVTE_DEBUG: - print('[Linear]: using FP8 forward') + print("[Linear]: using FP8 forward") - bias_dtype = ( - torch.bfloat16 - if activation_dtype == torch.float32 - else activation_dtype - ) + bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype bias = cast_if_needed(bias, bias_dtype) if use_bias else bias # Use FP8 weights @@ -168,13 +164,18 @@ class _Linear(torch.autograd.Function): tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], fp8_dtype_forward, - torch.uint8) + torch.uint8, + ) else: proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( - None, None, None, activation_dtype) + None, + None, + None, + activation_dtype, + ) if ub_overlap_rs: - ub_obj_projout = get_ub(ub_name+"_fprop") + ub_obj_projout = get_ub(ub_name + "_fprop") out = ub_obj_projout.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // tp_world_size @@ -182,7 +183,7 @@ class _Linear(torch.autograd.Function): rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) if ub_obj_projout.is_p2p_overlap(): if ub_obj_projout.is_atomic_gemm(): - ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P else: ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: @@ -206,8 +207,11 @@ class _Linear(torch.autograd.Function): weight_fp8._scale_inv, 0, weight_fp8._fp8_dtype, - inputmat_total._data - if isinstance(inputmat_total, Float8Tensor) else inputmat_total, + ( + inputmat_total._data + if isinstance(inputmat_total, Float8Tensor) + else inputmat_total + ), fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, @@ -221,11 +225,12 @@ class _Linear(torch.autograd.Function): ub=ub_obj_projout if ub_overlap_rs else None, extra_output_tensor=rs_out if ub_overlap_rs else None, out_index=proj_out_index, - fp8_meta_tensor = meta_tensor, - D_dtype = proj_out_tetype, + fp8_meta_tensor=meta_tensor, + D_dtype=proj_out_tetype, ) if is_first_module_in_mha: - out = 
Float8Tensor(data=out, + out = Float8Tensor( + data=out, fp8_meta=fp8_meta, fp8_meta_forward=True, fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, @@ -234,7 +239,7 @@ class _Linear(torch.autograd.Function): ) else: if _NVTE_DEBUG: - print('[Linear]: using non-FP8 forward') + print("[Linear]: using non-FP8 forward") # Cast for native AMP weight = cast_if_needed(weight, activation_dtype) @@ -243,15 +248,17 @@ class _Linear(torch.autograd.Function): if fp8_calibration: # amax of input amin, amax = inputmat_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \ - torch.max(-amin, amax).float() + fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( + -amin, amax + ).float() # amax of weight amin, amax = weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ - torch.max(-amin, amax).float() + fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( + -amin, amax + ).float() if ub_overlap_rs: - ub_obj_projout = get_ub(ub_name+"_fprop") + ub_obj_projout = get_ub(ub_name + "_fprop") out = ub_obj_projout.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group) @@ -308,8 +315,8 @@ class _Linear(torch.autograd.Function): ctx.fsdp_group = fsdp_group ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, - saved_inputmat, # None if fp8 == False - saved_inputmat_t, # None if fp8 == False AND not is_grad_enabled + saved_inputmat, # None if fp8 == False + saved_inputmat_t, # None if fp8 == False AND not is_grad_enabled weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None, ) @@ -342,8 +349,9 @@ class _Linear(torch.autograd.Function): ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad(inp, weight, bias): ctx.reduce_and_update_bwd_fp8_tensors = ( - ctx.reduce_and_update_bwd_fp8_tensors or - FP8GlobalStateManager.is_first_fp8_module()) + ctx.reduce_and_update_bwd_fp8_tensors + or FP8GlobalStateManager.is_first_fp8_module() + ) # Row Parallel Linear if ub_overlap_rs: @@ -356,14 +364,12 @@ class _Linear(torch.autograd.Function): # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) - @staticmethod - def backward( - ctx, grad_output: torch.Tensor - ) -> Tuple[Union[torch.Tensor, None], ...]: + def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: if isinstance(grad_output, Float8Tensor): ctx.fp8_meta["scaling_bwd"].scale_inv[ - tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_output._scale_inv + tex.FP8BwdTensors.GRAD_OUTPUT1 + ] = grad_output._scale_inv with torch.cuda.nvtx.range("_Linear_backward"): ( @@ -383,7 +389,8 @@ class _Linear(torch.autograd.Function): ctx.fsdp_shapes, inputmat, inputmat_t, - weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None) + weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None, + ) if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: weight = torch.nn.Parameter(weight, False) @@ -394,7 +401,7 @@ class _Linear(torch.autograd.Function): if ctx.ub_overlap_ag: dim_size = list(grad_output.size()) dim_size[0] = dim_size[0] * tp_world_size - ctx.ub_obj_gradout = get_ub(ctx.ub_name+"_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") if ctx.ub_obj_gradout.is_atomic_gemm(): ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P else: @@ -430,27 +437,28 @@ class _Linear(torch.autograd.Function): 
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=True - ) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) if ctx.requires_dgrad: if ctx.fp8: if _NVTE_DEBUG: - print('[Linear]: using FP8 backward') + print("[Linear]: using FP8 backward") if ctx.is_input_fp8: out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8BwdTensors.GRAD_INPUT1, ctx.fp8_meta["scaling_bwd"], fp8_dtype_backward, - torch.uint8) + torch.uint8, + ) else: out_index, meta_tensor, output_te_dtype, output_dtype = ( - None, None, None, ctx.activation_dtype) + None, + None, + None, + ctx.activation_dtype, + ) dgrad, _ = fp8_gemm( weight_fp8.transpose_2d(), weight_fp8._scale_inv, @@ -470,16 +478,17 @@ class _Linear(torch.autograd.Function): D_dtype=output_te_dtype, ) if output_dtype == torch.uint8: - dgrad = Float8Tensor(data=dgrad, + dgrad = Float8Tensor( + data=dgrad, fp8_meta=ctx.fp8_meta, fp8_meta_forward=False, fp8_meta_index=tex.FP8BwdTensors.GRAD_INPUT1, fp8_dtype=fp8_dtype_backward, dtype=ctx.activation_dtype, - ) + ) else: if _NVTE_DEBUG: - print('[Linear]: using non-FP8 backward') + print("[Linear]: using non-FP8 backward") dgrad, _, _ = gemm( weight, @@ -488,8 +497,11 @@ class _Linear(torch.autograd.Function): get_workspace(), layout="NN", grad=True, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P \ - if ctx.ub_overlap_ag else None, + ub_algo=( + tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + if ctx.ub_overlap_ag + else None + ), ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, ) @@ -517,10 +529,14 @@ class _Linear(torch.autograd.Function): inputmat_t_total = inputmat_total.transpose_2d() else: inputmat_t_total = tex.fp8_transpose( - inputmat_total, fp8_dtype_backward) + inputmat_total, fp8_dtype_backward + ) wgrad, _ = fp8_gemm( - inputmat_t_total._data - if isinstance(inputmat_t_total, Float8Tensor) else inputmat_t_total, + ( + inputmat_t_total._data + if isinstance(inputmat_t_total, Float8Tensor) + else inputmat_t_total + ), fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, @@ -572,20 +588,22 @@ class _Linear(torch.autograd.Function): if weight.requires_grad: # Handle custom DDP from mcore. 
- if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'): + if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"): weight.grad_added_to_main_grad = True - if getattr(weight, 'zero_out_wgrad', False): - wgrad = torch.zeros(weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False - ) + if getattr(weight, "zero_out_wgrad", False): + wgrad = torch.zeros( + weight.main_grad.shape, + dtype=weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) else: - wgrad = torch.empty(weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False - ) + wgrad = torch.empty( + weight.main_grad.shape, + dtype=weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) elif ctx.fuse_wgrad_accumulation: wgrad = None else: @@ -733,9 +751,8 @@ class Linear(TransformerEngineBaseModule): self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - if device == 'meta': - assert parameters_split is None, ("Cannot split module parameters " - "on 'meta' device.") + if device == "meta": + assert parameters_split is None, "Cannot split module parameters on 'meta' device." if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -833,10 +850,7 @@ class Linear(TransformerEngineBaseModule): # Check if parameters are subviews of buffers is_subview = (split_start, split_end) != (0, self.out_features) if is_subview and with_fp8_params: - raise RuntimeError( - "Splitting Float8Tensor into multiple params " - "is not supported" - ) + raise RuntimeError("Splitting Float8Tensor into multiple params is not supported") # Construct weight parameter self.register_parameter( @@ -867,7 +881,7 @@ class Linear(TransformerEngineBaseModule): if with_fp8_params: self.init_fp8_metadata() - self.reset_parameters(defer_init=(device == 'meta')) + self.reset_parameters(defer_init=(device == "meta")) # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM @@ -933,7 +947,7 @@ class Linear(TransformerEngineBaseModule): with self.prepare_forward( inp, is_first_microbatch, - allow_non_contiguous=isinstance(inp,Float8Tensor), + allow_non_contiguous=isinstance(inp, Float8Tensor), ) as inp: is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha @@ -944,8 +958,7 @@ class Linear(TransformerEngineBaseModule): if self.fp8: if len(unfused_weights) != 1: raise RuntimeError( - "Splitting Float8Tensor into multiple params " - "is not supported" + "Splitting Float8Tensor into multiple params is not supported" ) else: unfused_weights = [w.from_float8() for w in unfused_weights] @@ -972,8 +985,7 @@ class Linear(TransformerEngineBaseModule): update_transpose_cache = with_transpose if update_transpose_cache: update_transpose_cache = ( - is_first_microbatch - or skip_fp8_weight_update is not None + is_first_microbatch or skip_fp8_weight_update is not None ) if update_transpose_cache: weight_tensor.transpose_2d( @@ -982,10 +994,7 @@ class Linear(TransformerEngineBaseModule): ) else: # FP8 cast to workspace buffer - update_workspace = ( - is_first_microbatch is None - or is_first_microbatch - ) + update_workspace = is_first_microbatch is None or is_first_microbatch weight_fp8 = self.get_fp8_workspace( tensor=weight_tensor, fp8_meta_forward=True, diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index 
e1d2ac2551ef838eaf9e3ec4130195deef59d28c..969a468426f196b6a3211cb282a7a218b853b422 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -47,29 +47,31 @@ class _RMSNorm(torch.autograd.Function): rmsnorm_weight = cast_if_needed(rmsnorm_weight, activation_dtype) if is_grad_enabled: - rmsnorm_out, rsigma = tex.rmsnorm_fwd(inputmat, rmsnorm_weight, - eps, fwd_rmsnorm_sm_margin, - zero_centered_gamma) + rmsnorm_out, rsigma = tex.rmsnorm_fwd( + inputmat, rmsnorm_weight, eps, fwd_rmsnorm_sm_margin, zero_centered_gamma + ) ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma) ctx.inp_shape = inp.shape ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin ctx.zero_centered_gamma = zero_centered_gamma else: - rmsnorm_out = tex.rmsnorm_fwd_inf(inputmat, rmsnorm_weight, - eps, inf_rmsnorm_sm_margin, - zero_centered_gamma) + rmsnorm_out = tex.rmsnorm_fwd_inf( + inputmat, rmsnorm_weight, eps, inf_rmsnorm_sm_margin, zero_centered_gamma + ) return rmsnorm_out.view_as(inp) @staticmethod - def backward( - ctx, grad_output: torch.Tensor - ) -> Tuple[Union[torch.Tensor, None], ...]: + def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: inputmat, rmsnorm_weight, rsigma = ctx.saved_tensors grad_output = grad_output.contiguous() d_rmsnorm_out = grad_output.view(inputmat.shape) dxmat, dgamma = tex.rmsnorm_bwd( - d_rmsnorm_out, inputmat, rsigma, rmsnorm_weight, - ctx.bwd_rmsnorm_sm_margin, ctx.zero_centered_gamma + d_rmsnorm_out, + inputmat, + rsigma, + rmsnorm_weight, + ctx.bwd_rmsnorm_sm_margin, + ctx.zero_centered_gamma, ) return ( dxmat.view(ctx.inp_shape), @@ -145,7 +147,7 @@ class RMSNorm(torch.nn.Module): ) self.sequence_parallel = sequence_parallel - self.reset_parameters(defer_init=(device == 'meta')) + self.reset_parameters(defer_init=(device == "meta")) # These many SMs are subtracted from the total SM count when calling forward # and backward RMSNorm C APIs. These envvars can be used to prevent the LN @@ -158,10 +160,10 @@ class RMSNorm(torch.nn.Module): def reset_rms_norm_parameters(self) -> None: """Init RMSNorm params""" warnings.warn( - ("This method is deprecated and will be removed in an upcoming release. " - "Update your code to use RMSNorm.reset_parameters() instead."), + "This method is deprecated and will be removed in an upcoming release. 
" + "Update your code to use RMSNorm.reset_parameters() instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) if not self.zero_centered_gamma: init.ones_(self.weight) @@ -173,8 +175,8 @@ class RMSNorm(torch.nn.Module): if defer_init: return - if self.weight.device == torch.device('meta'): - self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device='cuda')) + if self.weight.device == torch.device("meta"): + self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda")) init.constant_(self.weight, float(not self.zero_centered_gamma)) setattr(self.weight, "sequence_parallel", self.sequence_parallel) diff --git a/transformer_engine/pytorch/numerics_debug.py b/transformer_engine/pytorch/numerics_debug.py index d28fdfe2adf88683378b631f5391e0984951b771..bc9a5f89e071dbd91038ac2c4d78192d94c7c43b 100644 --- a/transformer_engine/pytorch/numerics_debug.py +++ b/transformer_engine/pytorch/numerics_debug.py @@ -16,9 +16,7 @@ def debug(enabled: bool = True) -> None: _NUMERICS_DEBUG = enabled -def fp8_tensor_statistics( - tensor: torch.Tensor, fp8_format: str = "E4M3" -) -> Tuple[int, ...]: +def fp8_tensor_statistics(tensor: torch.Tensor, fp8_format: str = "E4M3") -> Tuple[int, ...]: """Print FP8 tensor stats""" fp8_format = fp8_format.upper() assert fp8_format in ( diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 0e069df48270d51dfa0e5ef536cba6deedb25fa6..91ce50239072d801ef2552d3812890529f507926 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -124,9 +124,7 @@ class FusedAdam(torch.optim.Optimizer): self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") self.multi_tensor_adam = tex.multi_tensor_adam self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable - self.multi_tensor_adam_capturable_master = ( - tex.multi_tensor_adam_capturable_master - ) + self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master def zero_grad(self): if self.set_grad_none: @@ -160,15 +158,11 @@ class FusedAdam(torch.optim.Optimizer): # per parameter step can be easily support by making it tensor, or pass list into kernel if "step" in group: group["step"] += ( - 1 - if not self.capturable - else (self._dummy_overflow_buf != 1).to(torch.int) + 1 if not self.capturable else (self._dummy_overflow_buf != 1).to(torch.int) ) else: group["step"] = ( - 1 - if not self.capturable - else torch.tensor([1], dtype=torch.int, device=device) + 1 if not self.capturable else torch.tensor([1], dtype=torch.int, device=device) ) # create lists for multi-tensor apply @@ -182,9 +176,7 @@ class FusedAdam(torch.optim.Optimizer): if p.grad is None: continue if p.grad.data.is_sparse: - raise RuntimeError( - "FusedAdam does not support sparse gradients." 
- ) + raise RuntimeError("FusedAdam does not support sparse gradients.") state = self.state[p] # State initialization diff --git a/transformer_engine/pytorch/optimizers/fused_sgd.py b/transformer_engine/pytorch/optimizers/fused_sgd.py index 4da3392759e839ba54ffce672180da8a2c755e69..6186f3f3ea4d851c6517d9263fea544c4cdc773e 100644 --- a/transformer_engine/pytorch/optimizers/fused_sgd.py +++ b/transformer_engine/pytorch/optimizers/fused_sgd.py @@ -182,13 +182,9 @@ class FusedSGD(Optimizer): if explicit_master_params: stash = self._amp_stash - fp32_params = [ - p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None - ] + fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None] fp32_grads = [ - p.grad - for p in stash.fp32_from_fp32_groups[gid] - if p.grad is not None + p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None ] fp32_momentums, first_runs[1] = self.get_momentums(fp32_params) @@ -199,14 +195,10 @@ class FusedSGD(Optimizer): if stash.fp32_from_fp16_groups[gid][i].grad is not None ] fp32_from_fp16_grads = [ - p.grad - for p in stash.fp32_from_fp16_groups[gid] - if p.grad is not None + p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None ] fp32_from_fp16_params = [ - p - for p in stash.fp32_from_fp16_groups[gid] - if p.grad is not None + p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None ] fp32_from_fp16_momentums, first_runs[0] = self.get_momentums( fp32_from_fp16_params @@ -219,9 +211,7 @@ class FusedSGD(Optimizer): fp16_model_params, ] else: - fp16_model_params = [ - p for p in stash.fp16_groups[gid] if p.grad is not None - ] + fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None] fp16_model_grads = [ p.grad for p in stash.fp16_groups[gid] if p.grad is not None ] @@ -244,9 +234,7 @@ class FusedSGD(Optimizer): launch_sets = [fp16_set, [fp32_grads, fp32_params, fp32_momentums]] else: fp16_params = [ - p - for p in group["params"] - if (p.dtype == torch.float16 and p.grad is not None) + p for p in group["params"] if (p.dtype == torch.float16 and p.grad is not None) ] fp16_grads = [ p.grad @@ -256,9 +244,7 @@ class FusedSGD(Optimizer): fp16_momentums, first_runs[0] = self.get_momentums(fp16_params) fp32_params = [ - p - for p in group["params"] - if (p.dtype == torch.float32 and p.grad is not None) + p for p in group["params"] if (p.dtype == torch.float32 and p.grad is not None) ] fp32_grads = [ p.grad diff --git a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py index e9b4b4e9f3c16abaed06a411e5c332417c012a28..b8d6d1f2639211399f10308176bc1e79ddb94b86 100644 --- a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py +++ b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py @@ -3,6 +3,8 @@ # See LICENSE for license information. 
"""Multi-tensor apply entry.""" + + class MultiTensorApply: # pylint: disable=too-few-public-methods """Multi-tensor apply entry.""" diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index d8e35dd5882bd908043172bf8154904b5588b6e5..b02b219fec9044668768cd85de33f4df61a3b6a4 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -38,12 +38,12 @@ CMakeBuildExtension = get_build_ext(BuildExtension) if __name__ == "__main__": # Extensions common_headers_dir = "common_headers" - copy_common_headers( - current_file_path.parent, - str(current_file_path / common_headers_dir)) + copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir)) ext_modules = [ setup_pytorch_extension( - "csrc", current_file_path / "csrc", current_file_path / common_headers_dir)] + "csrc", current_file_path / "csrc", current_file_path / common_headers_dir + ) + ] # Configure package setuptools.setup( @@ -56,9 +56,11 @@ if __name__ == "__main__": install_requires=["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"], tests_require=["numpy", "onnxruntime", "torchvision"], include_package_data=True, - package_data={"csrc": package_files("csrc"), - common_headers_dir: package_files(common_headers_dir), - "build_tools": package_files("build_tools")}, + package_data={ + "csrc": package_files("csrc"), + common_headers_dir: package_files(common_headers_dir), + "build_tools": package_files("build_tools"), + }, ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index 876f309fc148238908b3554aa36d95f1b92a1444..412486e4d3c9ab2b6f9c858e38f21362c99581e0 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -20,6 +20,7 @@ THREADS_PER_BLOCK = 128 _default_causal_mask = {} + def _get_default_causal_mask(sq: int, sk: int) -> torch.Tensor: """Return the causal upper triangular mask for softmax input""" if sq == 1: @@ -29,8 +30,8 @@ def _get_default_causal_mask(sq: int, sk: int) -> torch.Tensor: if matrix_shape not in _default_causal_mask: diagonal_offset = sk - sq + 1 _default_causal_mask[matrix_shape] = torch.triu( - torch.ones(sq, sk, dtype=torch.bool, device="cuda"), - diagonal=diagonal_offset) + torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset + ) return _default_causal_mask[matrix_shape] @@ -49,15 +50,17 @@ def _get_onnx_export_causal_mask( """ assert len(onnx_causal_mask.size()) == 2 assert onnx_causal_mask.size(0) == onnx_causal_mask.size(1) - assert onnx_causal_mask.size(0) >= (seq_k-seq_q) >= 0 - derived_mask = onnx_causal_mask[seq_k-seq_q:seq_k, :seq_k] + assert onnx_causal_mask.size(0) >= (seq_k - seq_q) >= 0 + derived_mask = onnx_causal_mask[seq_k - seq_q : seq_k, :seq_k] return derived_mask def fp32_compute(onnx_symbolic_fn): """A decorator that wraps an ONNX symoblic function with FP32 compute operators.""" + def wrapper(g: torch.Graph, inp: torch._C.Value, scale: float, *args, **kwargs): return compute_in_fp32(g, inp, onnx_symbolic_fn, scale, *args, **kwargs) + return wrapper @@ -73,17 +76,13 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor: """ScaledUpperTriangMaskedSoftmax fwd""" scale_t = torch.tensor([scale]) - softmax_results = tex.scaled_upper_triang_masked_softmax_forward( - inputs, scale_t[0] - ) + softmax_results = 
tex.scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0]) ctx.save_for_backward(softmax_results, scale_t) return softmax_results @staticmethod - def backward( - ctx, output_grads: torch.Tensor - ) -> Tuple[Union[torch.Tensor, None], ...]: + def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: """ScaledUpperTriangMaskedSoftmax bwd""" softmax_results, scale_t = ctx.saved_tensors input_grads = tex.scaled_upper_triang_masked_softmax_backward( @@ -96,8 +95,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): @fp32_compute def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value: """ScaledUpperTriangMaskedSoftmax symbolic method""" + def triangular_mask(): - dtype = _type_utils.JitScalarType.INT64 + dtype = _type_utils.JitScalarType.INT64 ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype) k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) mask = g.op("Trilu", ones, k, upper_i=1) @@ -109,7 +109,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) inv_mask = g.op("Sub", one, mask) - neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16)) + neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16)) softmax_mask = g.op("Mul", mask, neg_tenK) scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) @@ -132,16 +132,12 @@ class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function): def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor: """ScaledAlignedCausalMaskedSoftmax fwd""" scale_t = torch.tensor([scale]) - softmax_results = tex.scaled_aligned_causal_masked_softmax_forward( - inputs, scale_t[0] - ) + softmax_results = tex.scaled_aligned_causal_masked_softmax_forward(inputs, scale_t[0]) ctx.save_for_backward(softmax_results, scale_t) return softmax_results @staticmethod - def backward( - ctx, output_grads: torch.Tensor - ) -> Tuple[Union[torch.Tensor, None], ...]: + def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: """ScaledAlignedCausalMaskedSoftmax bwd""" softmax_results, scale_t = ctx.saved_tensors input_grads = tex.scaled_aligned_causal_masked_softmax_backward( @@ -154,8 +150,9 @@ class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function): @fp32_compute def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value: """ScaledAlignedCausalMaskedSoftmax symbolic method""" + def triangular_mask(): - dtype = _type_utils.JitScalarType.INT64 + dtype = _type_utils.JitScalarType.INT64 ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype) k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) @@ -173,7 +170,7 @@ class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function): one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) inv_mask = g.op("Sub", one, mask) - neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16)) + neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16)) softmax_mask = g.op("Mul", mask, neg_tenK) scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) @@ -193,9 +190,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function): """ @staticmethod - def forward( - ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float - ) -> torch.Tensor: + def forward(ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor: 
"""ScaledMaskedSoftmax fwd""" scale_t = torch.tensor([scale]) @@ -204,24 +199,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function): return softmax_results @staticmethod - def backward( - ctx, output_grads: torch.Tensor - ) -> Tuple[Union[torch.Tensor, None], ...]: + def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: """ScaledMaskedSoftmax bwd""" softmax_results, scale_t = ctx.saved_tensors - input_grads = tex.scaled_masked_softmax_backward( - output_grads, softmax_results, scale_t[0] - ) + input_grads = tex.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) return input_grads, None, None @staticmethod @fp32_compute def symbolic( - g: torch.Graph, - inputs: torch._C.Value, - mask: torch._C.Value, - scale: float) -> torch._C.Value: + g: torch.Graph, inputs: torch._C.Value, mask: torch._C.Value, scale: float + ) -> torch._C.Value: """ScaledMaskedSoftmax symbolic method""" # Captures the logic of function scaled_masked_softmax_warp_forward. # output = softmax(mask(input*scale) @@ -234,7 +223,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function): one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) inv_mask = g.op("Sub", one, mask) # Note: type is hard coded because softmax uses FP16 or BF16 - neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16)) + neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16)) softmax_mask = g.op("Mul", mask, neg_tenK) masked_scaled = g.op("Mul", inv_mask, scaled) masked = g.op("Add", masked_scaled, softmax_mask) @@ -259,15 +248,11 @@ class ScaledSoftmax(torch.autograd.Function): return softmax_results @staticmethod - def backward( - ctx, output_grads: torch.Tensor - ) -> Tuple[Union[torch.Tensor, None], ...]: + def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: """ScaledSoftmax bwd""" softmax_results, scale_t = ctx.saved_tensors - input_grads = tex.scaled_softmax_backward( - output_grads, softmax_results, scale_t[0] - ) + input_grads = tex.scaled_softmax_backward(output_grads, softmax_results, scale_t[0]) return input_grads, None, None @staticmethod @@ -280,7 +265,6 @@ class ScaledSoftmax(torch.autograd.Function): return out - class FusedScaleMaskSoftmax(nn.Module): """ fused operation: scaling + mask + softmax @@ -296,9 +280,7 @@ class FusedScaleMaskSoftmax(nn.Module): softmax_in_fp32: bool = True, ) -> None: super().__init__() - self.scaled_masked_softmax_fusion = bool( - int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1")) - ) + self.scaled_masked_softmax_fusion = bool(int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))) self.mask_func = mask_func self.softmax_in_fp32 = softmax_in_fp32 @@ -309,9 +291,10 @@ class FusedScaleMaskSoftmax(nn.Module): "onnx_causal_mask", torch.triu( torch.ones(self.kvcache_max_seq, self.kvcache_max_seq, device="cuda"), - diagonal=1 + diagonal=1, ).bool(), - persistent=False) + persistent=False, + ) def forward( self, @@ -328,9 +311,7 @@ class FusedScaleMaskSoftmax(nn.Module): self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 self.attn_mask_type = attn_mask_type - assert ( - scale is None or self.softmax_in_fp32 - ), "softmax should be in fp32 when scaled" + assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled" if self.is_kernel_available(mask, *inp.size()) and not is_in_onnx_export_mode(): return self.forward_fused_softmax(inp, mask, scale) @@ -351,11 +332,12 @@ class FusedScaleMaskSoftmax(nn.Module): if self.attn_mask_type == 
"arbitrary": return False # Custom masks not supported - if self.attn_mask_type == "causal": # unfused causal softmax kernel + if self.attn_mask_type == "causal": # unfused causal softmax kernel return True - if (sq % 4 == 0 # sq must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 + if ( + sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 and self.attn_mask_type != "arbitrary" # Custom masks not supported ): batch_per_block = self.get_batch_per_block(int(sk)) diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py index f485249c7d276c5504248a402245cb10c1f73d1a..05c1a5a0f5027433f1502a5caff0b003843ddfa6 100755 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ b/transformer_engine/pytorch/te_onnx_extensions.py @@ -84,9 +84,9 @@ def quantize(g, inputs, scale_inv, fp8_tensor): inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT) scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor])) - q_op = g.op( - make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType( - inputs.type().with_dtype(torch.uint8).with_sizes(output_shape)) + q_op = g.op(make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType( + inputs.type().with_dtype(torch.uint8).with_sizes(output_shape) + ) return q_op @@ -96,7 +96,8 @@ def dequantize(g, inputs, scale_inv, fp8_tensor, otype): scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor])) out = g.op(make_op_name("TRT_FP8DequantizeLinear"), inputs, scale).setType( - inputs.type().with_dtype(torch.float32).with_sizes(output_shape)) + inputs.type().with_dtype(torch.float32).with_sizes(output_shape) + ) # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT # custom ops, so cast the output if needed. @@ -230,10 +231,30 @@ def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): return geglu -@symbolic_helper.parse_args("v", "fs", "i", "i", "i", - "v", "fs", "i", "i", "i", - "v", "fs", "i", "fs", "v", "i", "v", "i", - "v", "i", "i", "i") +@symbolic_helper.parse_args( + "v", + "fs", + "i", + "i", + "i", + "v", + "fs", + "i", + "i", + "i", + "v", + "fs", + "i", + "fs", + "v", + "i", + "v", + "i", + "v", + "i", + "i", + "i", +) def onnx_te_gemm( g, weight, @@ -257,7 +278,8 @@ def onnx_te_gemm( workspace, workspaceSize, accumulate, - use_split_accumulator): + use_split_accumulator, +): """ONNX graph for te_gemm""" # pylint: disable=unused-argument is_fp16 = is_dtype_fp16(inputs) @@ -270,8 +292,9 @@ def onnx_te_gemm( empty_tensor_size = [0] bias_empty = torch.onnx.symbolic_helper._get_tensor_sizes(bias) == empty_tensor_size - pre_gelu_out_empty = torch.onnx.symbolic_helper._get_tensor_sizes(pre_gelu_out) \ - == empty_tensor_size + pre_gelu_out_empty = ( + torch.onnx.symbolic_helper._get_tensor_sizes(pre_gelu_out) == empty_tensor_size + ) if not bias_empty: output = g.op("Gemm", inputs, weight, bias, transA_i=trans_input, transB_i=trans_weight) @@ -297,16 +320,31 @@ def _ones_like(g, inp, dtype): # WAR ONNX spec: ConstantOfShape accepts all data types except for BF16. To WAR # create a ConstantOfShape with type FP32 and then add a Cast to BF16. 
is_bf16 = dtype == torch.bfloat16 - one = g.op("ConstantOfShape", shape, value_t=torch.tensor([1], - dtype=torch.float32 if is_bf16 else dtype)) + one = g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([1], dtype=torch.float32 if is_bf16 else dtype), + ) if is_bf16: one = g.op("Cast", one, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) return one @symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "i", "b") -def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, - scale_inv, fp8_tensor, otype, sm_margin, zero_centered_gamma): +def onnx_layernorm_fwd_fp8( + g, + inputs, + weight, + bias, + eps, + scale, + amax, + scale_inv, + fp8_tensor, + otype, + sm_margin, + zero_centered_gamma, +): """ONNX graph for layernorm_fwd_fp8""" # pylint: disable=unused-argument inp_dtype = get_TensorProtoDataType(inputs) @@ -340,7 +378,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_ga weight = g.op("Add", weight, one) axis = -len(normalized_shape) - ln = g.op( + ln = g.op( "LayerNormalization", inputs, weight, @@ -352,9 +390,21 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_ga ) return ln + @symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "i", "b") -def onnx_rmsnorm_fwd_fp8(g, inputs, weight, eps, scale, amax, - scale_inv, fp8_tensor, otype, sm_margin, zero_centered_gamma): +def onnx_rmsnorm_fwd_fp8( + g, + inputs, + weight, + eps, + scale, + amax, + scale_inv, + fp8_tensor, + otype, + sm_margin, + zero_centered_gamma, +): """ONNX graph for rmsnorm_fwd_fp8""" # pylint: disable=unused-argument inp_dtype = get_TensorProtoDataType(inputs) @@ -403,16 +453,16 @@ def onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma): return result -register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER) -register_custom_op_symbolic('tex_ts::cast_to_fp8_noalloc_ts', onnx_cast_to_fp8_noalloc, VER) -register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER) -register_custom_op_symbolic('tex_ts::gelu_ts', onnx_fp8_gelu, VER) -register_custom_op_symbolic('tex_ts::relu_ts', onnx_fp8_relu, VER) -register_custom_op_symbolic('tex_ts::reglu_ts', onnx_fp8_reglu, VER) -register_custom_op_symbolic('tex_ts::geglu_ts', onnx_fp8_geglu, VER) -register_custom_op_symbolic('tex_ts::swiglu_ts', onnx_fp8_swiglu, VER) -register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER) -register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER) -register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER) -register_custom_op_symbolic('tex_ts::rmsnorm_fwd_fp8_inf_ts', onnx_rmsnorm_fwd_fp8, VER) -register_custom_op_symbolic('tex_ts::rmsnorm_fwd_inf_ts', onnx_rmsnorm_fwd, VER) +register_custom_op_symbolic("tex_ts::cast_to_fp8_ts", onnx_cast_to_fp8, VER) +register_custom_op_symbolic("tex_ts::cast_to_fp8_noalloc_ts", onnx_cast_to_fp8_noalloc, VER) +register_custom_op_symbolic("tex_ts::cast_from_fp8_ts", onnx_cast_from_fp8, VER) +register_custom_op_symbolic("tex_ts::gelu_ts", onnx_fp8_gelu, VER) +register_custom_op_symbolic("tex_ts::relu_ts", onnx_fp8_relu, VER) +register_custom_op_symbolic("tex_ts::reglu_ts", onnx_fp8_reglu, VER) +register_custom_op_symbolic("tex_ts::geglu_ts", onnx_fp8_geglu, VER) +register_custom_op_symbolic("tex_ts::swiglu_ts", onnx_fp8_swiglu, VER) +register_custom_op_symbolic("tex_ts::te_gemm_ts", onnx_te_gemm, VER) +register_custom_op_symbolic("tex_ts::layernorm_fwd_fp8_inf_ts", 
onnx_layernorm_fwd_fp8, VER) +register_custom_op_symbolic("tex_ts::layernorm_fwd_inf_ts", onnx_layernorm_fwd, VER) +register_custom_op_symbolic("tex_ts::rmsnorm_fwd_fp8_inf_ts", onnx_rmsnorm_fwd_fp8, VER) +register_custom_op_symbolic("tex_ts::rmsnorm_fwd_inf_ts", onnx_rmsnorm_fwd, VER) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index ddeee548dd7d2c4333ce4f7a744da9bf16e9b44d..4a0854cdfff25899d2c8e3a76a10a58fc0f26ced 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -263,7 +263,7 @@ class TransformerLayer(torch.nn.Module): ub_overlap_rs: bool = True, ub_overlap_rs_dgrad: bool = False, bias: bool = True, - activation: str = 'gelu', + activation: str = "gelu", normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", attn_input_format: str = "sbhd", @@ -271,9 +271,7 @@ class TransformerLayer(torch.nn.Module): super().__init__() if ub_tp_comm_overlap: - assert ( - tex.userbuf_comm_available() - ), "Userbuffer communication backend not available." + assert tex.userbuf_comm_available(), "Userbuffer communication backend not available." self.self_attn_mask_type = self_attn_mask_type self.window_size = window_size @@ -289,16 +287,14 @@ class TransformerLayer(torch.nn.Module): self.layer_number = layer_number self.output_layernorm = output_layernorm self.layer_type = layer_type - self.apply_residual_connection_post_layernorm = ( - apply_residual_connection_post_layernorm - ) + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm if parallel_attention_mlp: assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'" - assert ( - not self.apply_residual_connection_post_layernorm - ), "parallel_attention and apply_residual_connection_post_layernorm "\ - "not supported simultaneously." + assert not self.apply_residual_connection_post_layernorm, ( + "parallel_attention and apply_residual_connection_post_layernorm " + "not supported simultaneously." 
+ ) assert ( not self.output_layernorm ), "parallel_attention and output_layernorm not supported simultaneously" @@ -315,9 +311,7 @@ class TransformerLayer(torch.nn.Module): if not fuse_qkv_params: qkv_weight_interleaved = False - self.kv_channels = ( - kv_channels if kv_channels else (hidden_size // num_attention_heads) - ) + self.kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads) if init_method is None: init_method = get_default_init_method() @@ -354,13 +348,13 @@ class TransformerLayer(torch.nn.Module): "set_parallel_mode": set_parallel_mode, "fuse_qkv_params": fuse_qkv_params, "zero_centered_gamma": zero_centered_gamma, - "qkv_weight_interleaved" : qkv_weight_interleaved, - "ub_bulk_wgrad" : ub_bulk_wgrad, - "ub_bulk_dgrad" : ub_bulk_dgrad, - "ub_overlap_ag" : ub_overlap_ag, - "ub_overlap_rs" : ub_overlap_rs, - "ub_overlap_rs_dgrad" : ub_overlap_rs_dgrad, - "qkv_format" : self.attn_input_format, + "qkv_weight_interleaved": qkv_weight_interleaved, + "ub_bulk_wgrad": ub_bulk_wgrad, + "ub_bulk_dgrad": ub_bulk_dgrad, + "ub_overlap_ag": ub_overlap_ag, + "ub_overlap_rs": ub_overlap_rs, + "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad, + "qkv_format": self.attn_input_format, } self.self_attention = MultiheadAttention( @@ -429,22 +423,18 @@ class TransformerLayer(torch.nn.Module): TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) - self.bias_dropout_add_exec_handler = ( - nullcontext if use_nvfuser else torch.enable_grad - ) + self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad if self.bias_dropout_fusion: set_jit_fusion_options() if seq_length and micro_batch_size: if self.sequence_parallel: seq_length = seq_length // self.tp_size - warmup_jit_bias_dropout_add_all_dtypes( - hidden_size, seq_length, micro_batch_size - ) + warmup_jit_bias_dropout_add_all_dtypes(hidden_size, seq_length, micro_batch_size) norm_module = { - "LayerNorm": LayerNorm, - "RMSNorm": RMSNorm, + "LayerNorm": LayerNorm, + "RMSNorm": RMSNorm, } if self.output_layernorm: self.layernorm = norm_module[normalization]( @@ -614,18 +604,14 @@ class TransformerLayer(torch.nn.Module): hidden_states.shape[0] == self.seq_length // self.tp_size ), "Sequence dimension must be split across TP group when using sequence parallel." - if (("padding" in self_attn_mask_type - or self_attn_mask_type == "arbitrary") - and attention_mask is not None): - assert ( - attention_mask.dtype == torch.bool - ), "Attention mask must be a boolean tensor" + if ( + "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary" + ) and attention_mask is not None: + assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor" # For AMP if torch.is_autocast_enabled(): - hidden_states = cast_if_needed( - hidden_states, torch.get_autocast_gpu_dtype() - ) + hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype()) # Self attention. 
self_attention_outputs = self.self_attention( @@ -709,9 +695,7 @@ class TransformerLayer(torch.nn.Module): bias_dropout_add_func = get_bias_dropout_add(self.training) with self.bias_dropout_add_exec_handler(): - output = bias_dropout_add_func( - hidden_state, bias, residual, self.hidden_dropout - ) + output = bias_dropout_add_func(hidden_state, bias, residual, self.hidden_dropout) else: if bias.numel() != 0: hidden_state = hidden_state + bias diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index de2ce218f28fa6798860e5cc51263cde7c60d594..e83369c67173870e93f687b2b7f3d1cedb44a29d 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -26,6 +26,7 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: Must be used carefully. """ from .float8_tensor import Float8Tensor + for t in tensors: if t is not None: if isinstance(t, Float8Tensor): @@ -58,12 +59,17 @@ def get_default_init_method() -> Callable: def init_method_constant(val: float) -> Callable: """Init method to set all tensor elements to a constant value.""" if val == 1.0: + def init_(tensor: torch.Tensor) -> Callable: return torch.nn.init.ones_(tensor) + elif val == 0.0: + def init_(tensor: torch.Tensor) -> Callable: return torch.nn.init.zeros_(tensor) + else: + def init_(tensor: torch.Tensor) -> Callable: return torch.nn.init.constant_(tensor, val) @@ -115,9 +121,7 @@ def compare_tensors(a: torch.Tensor, b: torch.Tensor) -> None: def ensure_divisibility(numerator: int, denominator: int) -> None: """Ensure that numerator is divisible by the denominator.""" - assert ( - numerator % denominator == 0 - ), f"{numerator} is not divisible by {denominator}" + assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}" def divide(numerator: int, denominator: int) -> int: @@ -181,9 +185,7 @@ def validate_rng_states_func(get_rng_tracker: Callable) -> None: validate_ctx_manager(rng_tracker.fork) -def assert_viewless_tensor( - tensor: torch.Tensor, extra_msg: Optional[str] = None -) -> torch.Tensor: +def assert_viewless_tensor(tensor: torch.Tensor, extra_msg: Optional[str] = None) -> torch.Tensor: """Assert that a tensor is not a view (i.e., its '._base' field is not set).""" if isinstance(tensor, list): @@ -191,23 +193,21 @@ def assert_viewless_tensor( if not isinstance(tensor, torch.Tensor): return tensor assert tensor._base is None, ( - f"Ensure tensor._base is None before setting tensor.data or storing " - f"tensor to memory buffer. Otherwise, a memory leak will occur (and " + "Ensure tensor._base is None before setting tensor.data or storing " + "tensor to memory buffer. Otherwise, a memory leak will occur (and " f"likely accumulate over iterations). {extra_msg}" ) return tensor -def safely_set_viewless_tensor_data( - tensor: torch.Tensor, new_data_tensor: torch.Tensor -) -> None: +def safely_set_viewless_tensor_data(tensor: torch.Tensor, new_data_tensor: torch.Tensor) -> None: """Safely set tensor's '.data' field. Check first that the tensor is viewless (i.e., '._base' not set). If not, raise an exception. """ extra_msg = ( - f"FYI, tensor._base has shape " + "FYI, tensor._base has shape " f"{'--' if tensor._base is None else tensor._base.shape}," f"and new_data_tensor has shape {new_data_tensor.shape}." 
) @@ -223,21 +223,13 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool: """Check if tensor dimensions are supported for FP8 TN GEMM""" - return ( - tensor.dim() == 2 - and tensor.size(0) % 8 == 0 - and tensor.size(1) % 16 == 0 - ) + return tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0 def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None: """Assert that tensor dimensions are supported for FP8 TN GEMM""" # single tensor check so it's clear which tensor is triggering the assertion - assert ( - tensor.dim() == 2 - and tensor.size(0) % 8 == 0 - and tensor.size(1) % 16 == 0 - ), ( + assert tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0, ( "FP8 execution requires 2D input matrices with " "height divisible by 8 and width divisible by 16, " f"but got tensor with dims={list(tensor.size())}" @@ -246,7 +238,7 @@ def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None: def is_bf16_compatible() -> None: """Replaces torch.cuda.is_bf16_compatible() with an explicit - check on device compute capability to enforce sm_80 or higher. + check on device compute capability to enforce sm_80 or higher. """ return torch.cuda.get_device_capability()[0] >= 8
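# --- Illustrative sketch (appended after the patch, not part of it) ---
# A minimal usage example of the helpers touched in the final utils.py hunk,
# assuming they keep the signatures shown above: check_dim_for_fp8_exec()
# requires a 2D tensor with rows divisible by 8 and columns divisible by 16,
# and is_bf16_compatible() gates BF16 on compute capability sm_80 or higher.
import torch
from transformer_engine.pytorch.utils import check_dim_for_fp8_exec, is_bf16_compatible

x = torch.randn(128, 1024, device="cuda")  # 128 % 8 == 0 and 1024 % 16 == 0
assert check_dim_for_fp8_exec(x)           # shape is valid for the FP8 TN GEMM path

# Fall back to FP16 on pre-Ampere devices, where BF16 is not supported.
dtype = torch.bfloat16 if is_bf16_compatible() else torch.float16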