"tests/vscode:/vscode.git/clone" did not exist on "30cad990d09fce3c37951d09c6ec085c1216a313"
Unverified commit 9416519d authored by Kirthi Shankar Sivamani, committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
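The diff below is a mechanical reformatting pass over the project's Python sources. As a minimal, hypothetical sketch of how such a change could be reproduced programmatically (assuming the `black` formatter, whose double-quoted strings, trailing commas, and exploded argument lists match the style of the new code; the 100-character line length is also an assumption, not stated in the commit):

import black

# Hypothetical snippet written in the pre-commit style; not taken verbatim from the repository.
unformatted = "qkv_layout = 'bshd_bshd_bshd'\ntols = dict(atol=5e-3,rtol=5e-3)\n"

# black.format_str() returns the reformatted source as a string.
formatted = black.format_str(unformatted, mode=black.Mode(line_length=100))
print(formatted)
# qkv_layout = "bshd_bshd_bshd"
# tols = dict(atol=5e-3, rtol=5e-3)

In practice the formatter is run over the whole tree (and typically enforced in CI), which is why the diff touches benchmark scripts, build tooling, docs configuration, and tutorial code alike.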
@@ -14,7 +14,7 @@ from tests.pytorch.fused_attn.test_fused_attn import (
     _is_flash_attention_supported,
     _is_fused_attention_supported,
     _is_unfused_attention_supported,
-    _run_dot_product_attention
+    _run_dot_product_attention,
 )
 
 pd.set_option("display.precision", 4)

@@ -28,7 +28,7 @@ ckpt_attn = False
 # workspace optimization path for cuDNN attention
 workspace_opt = True
 # QKV memory layout
-qkv_layout = 'bshd_bshd_bshd'
+qkv_layout = "bshd_bshd_bshd"
 # sliding window attention
 swa = False
 # padding between sequences for qkv_format=thd

@@ -38,16 +38,17 @@ is_training = True
 model_configs = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
     "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
     "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
     "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
     "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
 }
 
+
 def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
     config = model_configs[model]
     if dtype == torch.bfloat16:
         tols = dict(atol=2.5e-2, rtol=2.5e-2)
     else:
         tols = dict(atol=5e-3, rtol=5e-3)

@@ -57,17 +58,31 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
     for i in range(warmup_iters):
         if fused_attn_supported:
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FusedAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
         if flash_attn_supported:
             flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FlashAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FlashAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
     if fused_attn_supported and flash_attn_supported:
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
-        for i,_ in enumerate(flash_attn_bwd):
+        for i, _ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
 
     torch.cuda.cudart().cudaProfilerStart()

@@ -76,8 +91,15 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
     if fused_attn_supported:
         for i in range(num_iters):
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FusedAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
     torch.cuda.synchronize()
     fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0

@@ -87,81 +109,113 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
     if flash_attn_supported:
         for i in range(num_iters):
             flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FlashAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FlashAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
     torch.cuda.synchronize()
     flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0
 
-    df = pd.read_csv('times.csv')
-    df = pd.concat([
-        df,
-        pd.DataFrame(
-            [[fused_attn_time*1e3/num_iters, 0, 0, 0,
-            flash_attn_time*1e3/num_iters, 0, 0, 0, 0]], columns=df.columns)],
-        ignore_index=True
-    )
-    df.to_csv('times.csv',index=False)
+    df = pd.read_csv("times.csv")
+    df = pd.concat(
+        [
+            df,
+            pd.DataFrame(
+                [
+                    [
+                        fused_attn_time * 1e3 / num_iters,
+                        0,
+                        0,
+                        0,
+                        flash_attn_time * 1e3 / num_iters,
+                        0,
+                        0,
+                        0,
+                        0,
+                    ]
+                ],
+                columns=df.columns,
+            ),
+        ],
+        ignore_index=True,
+    )
+    df.to_csv("times.csv", index=False)
     torch.cuda.cudart().cudaProfilerStop()
 
 
 def parse_results(per_cudnn, per_flash, model):
-    filename = f'prof_{model}_cuda_gpu_trace.csv'
-    df = pd.read_csv(os.path.join('./',filename))
-    df_times = pd.read_csv('times.csv')
-    row = len(df_times.index)-1
+    filename = f"prof_{model}_cuda_gpu_trace.csv"
+    df = pd.read_csv(os.path.join("./", filename))
+    df_times = pd.read_csv("times.csv")
+    row = len(df_times.index) - 1
     if per_cudnn > 0:
-        t_cudnn_all = df[df['Name'].str.contains('cudnn')]['Duration (ns)'].to_numpy()
+        t_cudnn_all = df[df["Name"].str.contains("cudnn")]["Duration (ns)"].to_numpy()
         t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
         t_cudnn_avg = np.average(t_cudnn_all, axis=0)
-        df_times.loc[row, 'FusedAttention Kernels (fwd)'] = t_cudnn_avg[0]/1e6
-        df_times.loc[row, 'FusedAttention Kernels (bwd)'] = t_cudnn_avg[1:4].sum()/1e6
-        df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)'] = t_cudnn_avg.sum()/1e6
+        df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6
+        df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6
+        df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6
     if per_flash > 0:
-        t_flash_all = df[df['Name'].str.contains('void flash')]['Duration (ns)'].to_numpy()
+        t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
         t_flash_all = t_flash_all.reshape(-1, per_flash)
         t_flash_avg = np.average(t_flash_all, axis=0)
-        df_times.loc[row, 'FlashAttention Kernels (fwd)'] = t_flash_avg[0]/1e6
-        df_times.loc[row, 'FlashAttention Kernels (bwd)'] = t_flash_avg[1:4].sum()/1e6
-        df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] = t_flash_avg.sum()/1e6
+        df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
+        df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6
+        df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6
     if per_cudnn > 0 and per_flash > 0:
-        df_times.loc[row, 'Fused vs Flash Kernels Speedup (fwd+bwd)'] = \
-            df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] / \
-            df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)']
-    df_times.to_csv('times.csv',index=False)
+        df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = (
+            df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"]
+            / df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"]
+        )
+    df_times.to_csv("times.csv", index=False)
 
 
 def main():
     times = pd.DataFrame(
         columns=[
-            'FusedAttention Module',
-            'FusedAttention Kernels (fwd)',
-            'FusedAttention Kernels (bwd)',
-            'FusedAttention Kernels (fwd+bwd)',
-            'FlashAttention Module',
-            'FlashAttention Kernels (fwd)',
-            'FlashAttention Kernels (bwd)',
-            'FlashAttention Kernels (fwd+bwd)',
-            'Fused vs Flash Kernels Speedup (fwd+bwd)',
-        ])
-    times.to_csv('times.csv',index=False)
+            "FusedAttention Module",
+            "FusedAttention Kernels (fwd)",
+            "FusedAttention Kernels (bwd)",
+            "FusedAttention Kernels (fwd+bwd)",
+            "FlashAttention Module",
+            "FlashAttention Kernels (fwd)",
+            "FlashAttention Kernels (bwd)",
+            "FlashAttention Kernels (fwd+bwd)",
+            "Fused vs Flash Kernels Speedup (fwd+bwd)",
+        ]
+    )
+    times.to_csv("times.csv", index=False)
     device_id = torch.cuda.current_device()
     device_properties = torch.cuda.get_device_properties(device_id)
-    print(f"Device {device_id}: "
+    print(
+        f"Device {device_id}: "
         f"{device_properties.name} GPU, "
         f"sm{device_properties.major}{device_properties.minor} compute capability, "
-        f"{device_properties.total_memory/1024**3:.1f}GB memory")
+        f"{device_properties.total_memory/1024**3:.1f}GB memory"
+    )
     for model in model_configs.keys():
         config = model_configs[model]
         fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
-            config, dtype, qkv_layout=qkv_layout,
+            config,
+            dtype,
+            qkv_layout=qkv_layout,
         )
         fused_attn_supported = fused_attn_supported and not swa
         flash_attn_supported = _is_flash_attention_supported(config)
-        print(f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
-            f'{" and flash-attention" if flash_attn_supported else ""}...')
+        print(
+            f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
+            f'{" and flash-attention" if flash_attn_supported else ""}...'
+        )
         prof_cmd = [
             "nsys",

@@ -175,8 +229,8 @@ def main():
             f""" "import benchmark_attention;""",
             f"""benchmark_attention.benchmark_dot_product_attention("""
             f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
         ]
-        prof_cmd = ' '.join(prof_cmd)
+        prof_cmd = " ".join(prof_cmd)
         subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
         stats_cmd = [
             "nsys",

@@ -190,17 +244,17 @@ def main():
             "--force-export=true",
             f"--output=prof_{model}",
             f"prof_{model}.nsys-rep",
         ]
         if fused_attn_supported:
             num_kernels_cudnn = 4
-            if config.attn_bias_type == 'post_scale_bias':
-                num_kernels_cudnn = num_kernels_cudnn+1
+            if config.attn_bias_type == "post_scale_bias":
+                num_kernels_cudnn = num_kernels_cudnn + 1
             if config.num_heads != config.num_gqa_groups:
-                num_kernels_cudnn = num_kernels_cudnn+2
+                num_kernels_cudnn = num_kernels_cudnn + 2
         else:
             num_kernels_cudnn = 0
         num_kernels_flash = 4 if flash_attn_supported else 0
-        stats_cmd = ' '.join(stats_cmd)
+        stats_cmd = " ".join(stats_cmd)
         subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
         parse_cmd = [
             "python",

@@ -208,18 +262,23 @@ def main():
             f""" "import benchmark_attention;""",
             f"""benchmark_attention.parse_results("""
             f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
         ]
-        parse_cmd = ' '.join(parse_cmd)
+        parse_cmd = " ".join(parse_cmd)
         subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
 
-    df_times = pd.read_csv('times.csv')
+    df_times = pd.read_csv("times.csv")
     df_times.index = list(model_configs.keys())
-    a=df_times[['FusedAttention Kernels (fwd+bwd)',
-        'FlashAttention Kernels (fwd+bwd)',
-        'Fused vs Flash Kernels Speedup (fwd+bwd)']]
-    a.columns = ['cuDNN fwd+bwd (ms)', 'flash-attn fwd+bwd (ms)', 'cuDNN vs flash speedup']
+    a = df_times[
+        [
+            "FusedAttention Kernels (fwd+bwd)",
+            "FlashAttention Kernels (fwd+bwd)",
+            "Fused vs Flash Kernels Speedup (fwd+bwd)",
+        ]
+    ]
+    a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
     print()
     print(a)
 
 
 if __name__ == "__main__":
     main()
@@ -64,6 +64,7 @@ class CMakeExtension(setuptools.Extension):
             configure_command.append("-GNinja")
 
         import pybind11
+
         pybind11_dir = Path(pybind11.__file__).resolve().parent
         pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
         configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")

@@ -130,6 +131,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
             else:
                 # Only during release sdist build.
                 import transformer_engine
+
                 search_paths = list(Path(transformer_engine.__path__[0]).iterdir())
                 del transformer_engine

@@ -142,8 +144,9 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
                 # Figure out stub file path
                 module_name = paddle_ext.name
-                assert module_name.endswith("_pd_"), \
-                    "Expected Paddle extension module to end with '_pd_'"
+                assert module_name.endswith(
+                    "_pd_"
+                ), "Expected Paddle extension module to end with '_pd_'"
                 stub_name = module_name[:-4]  # remove '_pd_'
                 stub_path = os.path.join(self.build_lib, "transformer_engine", stub_name + ".py")
                 Path(stub_path).parent.mkdir(exist_ok=True, parents=True)

@@ -158,6 +161,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
                 # Write stub file
                 print(f"Writing Paddle stub for {lib_name} into file {stub_path}")
                 from paddle.utils.cpp_extension.extension_utils import custom_write_stub
+
                 custom_write_stub(lib_name, stub_path)
 
             # Ensure that binaries are not in global package space.

@@ -182,13 +186,14 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
            # extra_compile_args is a dict.
            for ext in self.extensions:
                if isinstance(ext.extra_compile_args, dict):
-                    for target in ['cxx', 'nvcc']:
+                    for target in ["cxx", "nvcc"]:
                        if target not in ext.extra_compile_args.keys():
                            ext.extra_compile_args[target] = []
 
            # Define new _compile method that redirects to NVCC for .cu and .cuh files.
            original_compile_fn = self.compiler._compile
-            self.compiler.src_extensions += ['.cu', '.cuh']
+            self.compiler.src_extensions += [".cu", ".cuh"]
 
            def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
                # Copy before we make any modifications.
                cflags = copy.deepcopy(extra_postargs)

@@ -197,31 +202,31 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
                _, nvcc_bin = cuda_path()
                original_compiler = self.compiler.compiler_so
-                    if os.path.splitext(src)[1] in ['.cu', '.cuh']:
-                        self.compiler.set_executable('compiler_so', str(nvcc_bin))
+                    if os.path.splitext(src)[1] in [".cu", ".cuh"]:
+                        self.compiler.set_executable("compiler_so", str(nvcc_bin))
                        if isinstance(cflags, dict):
-                            cflags = cflags['nvcc']
+                            cflags = cflags["nvcc"]
 
                        # Add -fPIC if not already specified
-                        if not any('-fPIC' in flag for flag in cflags):
-                            cflags.extend(['--compiler-options', "'-fPIC'"])
+                        if not any("-fPIC" in flag for flag in cflags):
+                            cflags.extend(["--compiler-options", "'-fPIC'"])
 
                        # Forward unknown options
-                        if not any('--forward-unknown-opts' in flag for flag in cflags):
-                            cflags.append('--forward-unknown-opts')
+                        if not any("--forward-unknown-opts" in flag for flag in cflags):
+                            cflags.append("--forward-unknown-opts")
 
                    elif isinstance(cflags, dict):
-                        cflags = cflags['cxx']
+                        cflags = cflags["cxx"]
 
                    # Append -std=c++17 if not already in flags
-                    if not any(flag.startswith('-std=') for flag in cflags):
-                        cflags.append('-std=c++17')
+                    if not any(flag.startswith("-std=") for flag in cflags):
+                        cflags.append("-std=c++17")
 
                    return original_compile_fn(obj, src, ext, cc_args, cflags, pp_opts)
                finally:
                    # Put the original compiler back in place.
-                    self.compiler.set_executable('compiler_so', original_compiler)
+                    self.compiler.set_executable("compiler_so", original_compiler)
 
            self.compiler._compile = _compile_fn

...
@@ -36,8 +36,8 @@ def setup_jax_extension(
     ]
 
     # Compile flags
-    cxx_flags = [ "-O3" ]
-    nvcc_flags = [ "-O3" ]
+    cxx_flags = ["-O3"]
+    nvcc_flags = ["-O3"]
 
     # Define TE/JAX as a Pybind11Extension
     from pybind11.setup_helpers import Pybind11Extension

@@ -47,9 +47,9 @@ def setup_jax_extension(
         def _add_cflags(self, flags: List[str]) -> None:
             if isinstance(self.extra_compile_args, dict):
-                cxx_flags = self.extra_compile_args.pop('cxx', [])
+                cxx_flags = self.extra_compile_args.pop("cxx", [])
                 cxx_flags += flags
-                self.extra_compile_args['cxx'] = cxx_flags
+                self.extra_compile_args["cxx"] = cxx_flags
             else:
                 self.extra_compile_args[:0] = flags

@@ -57,8 +57,5 @@ def setup_jax_extension(
         "transformer_engine_jax",
         sources=[str(path) for path in sources],
         include_dirs=[str(path) for path in include_dirs],
-        extra_compile_args={
-            "cxx": cxx_flags,
-            "nvcc": nvcc_flags
-        },
+        extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags},
     )
@@ -76,11 +76,12 @@ def setup_pytorch_extension(
     # Libraries -- PyTorch CUDAExtension links to libcudart.so but not to libcuda.so
     cuda_home, _ = cuda_path()
-    library_dirs = [ cuda_home / "compat" / "lib" ]
-    libraries = [ "cuda" ]
+    library_dirs = [cuda_home / "compat" / "lib"]
+    libraries = ["cuda"]
     if os.getenv("UB_MPI_BOOTSTRAP"):
-        assert os.getenv("MPI_HOME") is not None, \
-            "MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1"
+        assert (
+            os.getenv("MPI_HOME") is not None
+        ), "MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1"
         mpi_home = Path(os.getenv("MPI_HOME"))
         include_dirs.append(mpi_home / "include")
         cxx_flags.append("-DUB_MPI_BOOTSTRAP")

@@ -95,12 +96,12 @@ def setup_pytorch_extension(
     return CUDAExtension(
         name="transformer_engine_torch",
-        sources=[ str(src) for src in sources ],
-        include_dirs=[ str(inc) for inc in include_dirs ],
+        sources=[str(src) for src in sources],
+        include_dirs=[str(inc) for inc in include_dirs],
         extra_compile_args={
             "cxx": cxx_flags,
             "nvcc": nvcc_flags,
         },
-        libraries=[ str(lib) for lib in libraries ],
-        library_dirs=[ str(lib_dir) for lib_dir in library_dirs ],
+        libraries=[str(lib) for lib in libraries],
+        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
     )
@@ -18,11 +18,12 @@ def te_version() -> str:
     root_path = Path(__file__).resolve().parent
     with open(root_path / "VERSION.txt", "r") as f:
         version = f.readline().strip()
-    if (not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0"))
-        and not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))):
+    if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")) and not bool(
+        int(os.getenv("NVTE_RELEASE_BUILD", "0"))
+    ):
         try:
             output = subprocess.run(
-                ["git", "rev-parse" , "--short", "HEAD"],
+                ["git", "rev-parse", "--short", "HEAD"],
                 capture_output=True,
                 cwd=root_path,
                 check=True,
...

@@ -174,7 +174,7 @@ def cuda_version() -> Tuple[int, ...]:
         universal_newlines=True,
     )
     match = re.search(r"release\s*([\d.]+)", output.stdout)
-    version = match.group(1).split('.')
+    version = match.group(1).split(".")
     return tuple(int(v) for v in version)

@@ -224,9 +224,7 @@ def get_frameworks() -> List[str]:
     _frameworks = [framework.lower() for framework in _frameworks]
     for framework in _frameworks:
         if framework not in supported_frameworks:
-            raise ValueError(
-                f"Transformer Engine does not support framework={framework}"
-            )
+            raise ValueError(f"Transformer Engine does not support framework={framework}")
 
     return _frameworks

@@ -242,8 +240,8 @@ def package_files(directory):
 def copy_common_headers(te_src, dst):
     headers = te_src / "common"
-    for file_path in glob.glob(os.path.join(str(headers), "**", '*.h'), recursive=True):
-        new_path = os.path.join(dst, file_path[len(str(te_src)) + 1:])
+    for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True):
+        new_path = os.path.join(dst, file_path[len(str(te_src)) + 1 :])
         Path(new_path).parent.mkdir(exist_ok=True, parents=True)
         shutil.copy(file_path, new_path)

@@ -251,9 +249,10 @@ def copy_common_headers(te_src, dst):
 def install_and_import(package):
     """Install a package via pip (if not already installed) and import into globals."""
     import importlib
+
     try:
         importlib.import_module(package)
     except ImportError:
-        subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
+        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
     finally:
         globals()[package] = importlib.import_module(package)
@@ -28,21 +28,26 @@ if current_year == release_year:
 else:
     copyright_year = str(release_year) + "-" + str(current_year)
 
-project = u'Transformer Engine'
-copyright = u'{}, NVIDIA CORPORATION & AFFILIATES. All rights reserved.'.format(copyright_year)
-author = u'NVIDIA CORPORATION & AFFILIATES'
+project = "Transformer Engine"
+copyright = "{}, NVIDIA CORPORATION & AFFILIATES. All rights reserved.".format(copyright_year)
+author = "NVIDIA CORPORATION & AFFILIATES"
 
 git_sha = os.getenv("GIT_SHA")
 
 if not git_sha:
     try:
-        git_sha = subprocess.check_output(["git", "log", "--pretty=format:'%h'", "-n1"]).decode('ascii').replace("'","").strip()
+        git_sha = (
+            subprocess.check_output(["git", "log", "--pretty=format:'%h'", "-n1"])
+            .decode("ascii")
+            .replace("'", "")
+            .strip()
+        )
     except:
-        git_sha = u'0000000'
+        git_sha = "0000000"
 
 git_sha = git_sha[:7] if len(git_sha) > 7 else git_sha
 
-version = str(te_version + u"-" + git_sha)
+version = str(te_version + "-" + git_sha)
 release = te_version
 
 # hack: version is used for html creation, so put the version picker

@@ -51,58 +56,60 @@ option_on = " selected"
 option_off = ""
 release_opt = option_on
 option_nr = 0
-version = version + """<br/>
+version = (
+    version
+    + """<br/>
 Version select: <select onChange="window.location.href = this.value" onFocus="this.selectedIndex = {0}">
     <option value="https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html"{1}>Current release</option>
     <option value="https://docs.nvidia.com/deeplearning/transformer-engine/documentation-archive.html">Older releases</option>
-</select>""".format(option_nr, release_opt)
+</select>""".format(
+        option_nr, release_opt
+    )
+)
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
 
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.mathjax',
-    'sphinx.ext.napoleon',
-    'sphinx.ext.ifconfig',
-    'nbsphinx',
-    'breathe',
-    'autoapi.extension',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.napoleon",
+    "sphinx.ext.ifconfig",
+    "nbsphinx",
+    "breathe",
+    "autoapi.extension",
 ]
 
-templates_path = ['_templates']
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
-source_suffix = '.rst'
-master_doc = 'index'
-pygments_style = 'sphinx'
+templates_path = ["_templates"]
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+source_suffix = ".rst"
+master_doc = "index"
+pygments_style = "sphinx"
 
 # -- Options for HTML output -------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
 
-html_theme = 'sphinx_rtd_theme'
+html_theme = "sphinx_rtd_theme"
 html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
-html_static_path = ['_static']
+html_static_path = ["_static"]
 html_show_sphinx = False
 html_css_files = [
-    'css/nvidia_font.css',
-    'css/nvidia_footer.css',
+    "css/nvidia_font.css",
+    "css/nvidia_footer.css",
 ]
-html_theme_options = {
-    'display_version': True,
-    'collapse_navigation': False,
-    'logo_only': False
-}
+html_theme_options = {"display_version": True, "collapse_navigation": False, "logo_only": False}
 
-napoleon_custom_sections = [('Parallelism parameters', 'params_style'),
-                            ('Optimization parameters', 'params_style'),
-                            ('Values', 'params_style')]
+napoleon_custom_sections = [
+    ("Parallelism parameters", "params_style"),
+    ("Optimization parameters", "params_style"),
+    ("Values", "params_style"),
+]
 
 breathe_projects = {"TransformerEngine": os.path.abspath("doxygen/xml/")}
 breathe_default_project = "TransformerEngine"
...
@@ -4,8 +4,8 @@
 import os
 import torch
 from typing import Tuple
 
 from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
 from transformer_engine.pytorch.distributed import _set_cuda_rng_state
 from transformer_engine.pytorch.attention import DotProductAttention

@@ -18,87 +18,105 @@ _cuda_rng_state = torch.cuda.get_rng_state()
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
 
 
 def reset_rng_states() -> None:
     """Revert back to initial RNG state"""
     torch.set_rng_state(_cpu_rng_state)
     _set_cuda_rng_state(_cuda_rng_state)
 
 
 def _run_dot_product_attention(
     dtype: torch.dtype,
     config: ModelConfig,
     qkv_layout: str,
 ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
     """Run DotProductAttention module with one forward pass and one backward pass"""
     reset_rng_states()
 
-    seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
-        dtype=torch.int32, device="cuda")
-    seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
-        dtype=torch.int32, device="cuda")
-    inp = torch.randn([config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
-        dtype=dtype, device="cuda")
-    q = inp[:,:,0,:,:]
-    k = inp[:,:,1,:,:]
-    v = inp[:,:,2,:,:]
+    seqlens_q = torch.full(
+        [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
+    )
+    seqlens_kv = torch.full(
+        [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
+    )
+    inp = torch.randn(
+        [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
+        dtype=dtype,
+        device="cuda",
+    )
+    q = inp[:, :, 0, :, :]
+    k = inp[:, :, 1, :, :]
+    v = inp[:, :, 2, :, :]
     q.requires_grad = True
     k.requires_grad = True
     v.requires_grad = True
-    out_grad = torch.randn([config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
-        dtype=dtype, device="cuda")
+    out_grad = torch.randn(
+        [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
+        dtype=dtype,
+        device="cuda",
+    )
 
     # Create attention mask / bias
     attention_mask = None
     bias = None
     if config.attn_mask_type == "arbitrary":
-        attention_mask = torch.randint(-10,10,
-            [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv]).to(
-            dtype=torch.bool, device="cuda")
+        attention_mask = torch.randint(
+            -10,
+            10,
+            [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv],
+        ).to(dtype=torch.bool, device="cuda")
     if config.attn_bias_type == "post_scale_bias":
         # convert mask to bias
-        attention_mask = torch.randint(-10,10,
-            [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv]).to(
-            dtype=torch.bool, device="cuda")
+        attention_mask = torch.randint(
+            -10,
+            10,
+            [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv],
+        ).to(dtype=torch.bool, device="cuda")
         bias = attention_mask.clone()
-        neginf = -2**50 if dtype == torch.bfloat16 else -2**15
-        bias = torch.where(bias==0, 0, neginf).to(dtype=dtype, device='cuda')
+        neginf = -(2**50) if dtype == torch.bfloat16 else -(2**15)
+        bias = torch.where(bias == 0, 0, neginf).to(dtype=dtype, device="cuda")
         bias.requires_grad = False
         attention_mask = None
 
-    block = (
-        DotProductAttention(
-            config.num_heads,
-            config.head_dim,
-            num_gqa_groups=config.num_gqa_groups,
-            qkv_format='bshd',
-            attention_dropout=config.dropout_p,
-            sequence_parallel=False,
-            tp_size=1,
-            get_rng_state_tracker=None,
-            tp_group=None,
-            layer_number=1,
-        ).to(dtype=dtype, device="cuda")
-    )
+    block = DotProductAttention(
+        config.num_heads,
+        config.head_dim,
+        num_gqa_groups=config.num_gqa_groups,
+        qkv_format="bshd",
+        attention_dropout=config.dropout_p,
+        sequence_parallel=False,
+        tp_size=1,
+        get_rng_state_tracker=None,
+        tp_group=None,
+        layer_number=1,
+    ).to(dtype=dtype, device="cuda")
 
     # Run a forward and backward pass
     out = None
     if config.attn_mask_type == "arbitrary":
-        out = block(q, k, v,
-            attention_mask=attention_mask,  # attention_mask
-            qkv_format='bshd',
-            attn_mask_type=config.attn_mask_type,  # 'arbitrary'
-            core_attention_bias_type=config.attn_bias_type,  # 'no_bias'
-            core_attention_bias=bias,  # None
-        )
+        out = block(
+            q,
+            k,
+            v,
+            attention_mask=attention_mask,  # attention_mask
+            qkv_format="bshd",
+            attn_mask_type=config.attn_mask_type,  # 'arbitrary'
+            core_attention_bias_type=config.attn_bias_type,  # 'no_bias'
+            core_attention_bias=bias,  # None
+        )
         out.backward(out_grad)
     if config.attn_bias_type == "post_scale_bias":
-        out = block(q, k, v,
-            attention_mask=attention_mask,  # None
-            qkv_format='bshd',
-            attn_mask_type=config.attn_mask_type,  # no_mask
-            core_attention_bias_type=config.attn_bias_type,  # 'post_scale_bias'
-            core_attention_bias=bias,  # bias
-        )
+        out = block(
+            q,
+            k,
+            v,
+            attention_mask=attention_mask,  # None
+            qkv_format="bshd",
+            attn_mask_type=config.attn_mask_type,  # no_mask
+            core_attention_bias_type=config.attn_bias_type,  # 'post_scale_bias'
+            core_attention_bias=bias,  # bias
+        )
         out.backward(out_grad)
 
     return out, (q.grad, k.grad, v.grad)

@@ -107,19 +125,19 @@ def _run_dot_product_attention(
 dtype = torch.bfloat16
 model_configs = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
     "test_mask": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "arbitrary", "no_bias"),
     "test_bias": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
 }
 
-print('Run with post_scale_bias:')
+print("Run with post_scale_bias:")
 config = model_configs["test_bias"]
-fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd')
+fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
 
-print('Run with arbitrary mask:')
+print("Run with arbitrary mask:")
 config = model_configs["test_mask"]
-unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd')
+unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
 
 torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
 for i in range(3):
     torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2)
-print('Test passed!')
+print("Test passed!")
@@ -14,7 +14,7 @@ from tests.pytorch.fused_attn.test_fused_attn import (
     _is_flash_attention_supported,
     _is_fused_attention_supported,
     _is_unfused_attention_supported,
-    _run_dot_product_attention
+    _run_dot_product_attention,
 )
 
 # data type

@@ -26,7 +26,7 @@ ckpt_attn = False
 # workspace optimization path for cuDNN attention
 workspace_opt = True
 # QKV memory layout
-qkv_layout = 'bshd_bshd_bshd'
+qkv_layout = "bshd_bshd_bshd"
 # sliding window attention
 swa = False
 # padding between sequences for qkv_format=thd

@@ -36,12 +36,13 @@ is_training = True
 model_configs = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
     "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
     "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
     "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
     "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
 }
 
+
 def example_attention(model, fused_attn_supported, flash_attn_supported):
     config = model_configs[model]
     if dtype == torch.bfloat16:

@@ -51,40 +52,58 @@ def example_attention(model, fused_attn_supported, flash_attn_supported):
     if fused_attn_supported:
         print()
-        print('Run cuDNN attention...')
+        print("Run cuDNN attention...")
         fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-            dtype, config, "FusedAttention",
-            ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+            dtype,
+            config,
+            "FusedAttention",
+            ckpt_attn,
+            qkv_layout,
+            workspace_opt,
+            swa,
+            pad_between_seqs,
+            is_training,
         )
 
     if flash_attn_supported:
         print()
-        print('Run flash-attention...')
+        print("Run flash-attention...")
         flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-            dtype, config, "FlashAttention",
-            ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+            dtype,
+            config,
+            "FlashAttention",
+            ckpt_attn,
+            qkv_layout,
+            workspace_opt,
+            swa,
+            pad_between_seqs,
+            is_training,
         )
 
     if fused_attn_supported and flash_attn_supported:
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
-        for i,_ in enumerate(flash_attn_bwd):
+        for i, _ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
         print()
-        print('Test passed.')
+        print("Test passed.")
 
 
 def main():
-    models = ['test_0']
+    models = ["test_0"]
     for model in models:
         config = model_configs[model]
         fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
-            config, dtype, qkv_layout=qkv_layout,
+            config,
+            dtype,
+            qkv_layout=qkv_layout,
        )
         fused_attn_supported = fused_attn_supported and not swa
         flash_attn_supported = _is_flash_attention_supported(config)
         example_attention(model, fused_attn_supported, flash_attn_supported)
 
 
 if __name__ == "__main__":
     main()
@@ -8,14 +8,15 @@ import torch
 import transformer_engine.pytorch as te
 from transformer_engine.pytorch.fp8 import DelayedScaling, dist_group_type
 
+
 def speedometer(
     module: torch.nn.Module,
     input: torch.Tensor,
     output_grad: torch.Tensor,
     forward_kwargs: dict = {},
     fp8_autocast_kwargs: Optional[dict] = None,
     timing_iters: int = 50,
     warmup_iters: int = 50,
 ) -> None:
     """Measure average run time for a PyTorch module

@@ -24,7 +25,7 @@ def speedometer(
     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)
     if fp8_autocast_kwargs is None:
-        fp8_autocast_kwargs = { "enabled": False }
+        fp8_autocast_kwargs = {"enabled": False}
 
     # Warmup runs
     torch.cuda.synchronize()

@@ -51,11 +52,12 @@ class DotProductAttention(torch.nn.Module):
 
     Built with plain PyTorch modules.
 
     """
+
     def __init__(
         self,
         num_attention_heads: int,
         kv_channels: int,
         attention_dropout: float,
     ) -> None:
         super().__init__()
         self.projection_size = kv_channels * num_attention_heads

@@ -63,21 +65,17 @@ class DotProductAttention(torch.nn.Module):
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
         self.dropout = torch.nn.Dropout(attention_dropout)
 
-    def masked_softmax(
-        self,
-        inp: torch.Tensor,
-        mask: Optional[torch.Tensor]
-    ) -> torch.Tensor:
+    def masked_softmax(self, inp: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
         if mask is not None:
             inp.masked_fill_(mask, -10000.0)
         return torch.nn.Softmax(dim=-1)(inp)
 
     def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         b = query.size(1)
         np = query.size(2)

@@ -90,7 +88,9 @@ class DotProductAttention(torch.nn.Module):
         # [sk, b, np, hn] -> [sk, b * np, hn]
         key = key.view(sk, b * np, -1)
 
-        bmm1 = torch.bmm(query.transpose(0, 1), key.transpose(0, 1).transpose(1, 2)) / self.norm_factor
+        bmm1 = (
+            torch.bmm(query.transpose(0, 1), key.transpose(0, 1).transpose(1, 2)) / self.norm_factor
+        )
 
         # change view to [b, np, sq, sk]
         attention_scores = bmm1.view(b, np, sq, sk)

@@ -126,10 +126,11 @@ class BasicMLP(torch.nn.Module):
 
     Built with plain PyTorch modules.
 
     """
+
     def __init__(
         self,
         hidden_size: int,
         ffn_hidden_size: int,
     ) -> None:
         super().__init__()
         self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True)

@@ -137,7 +138,7 @@ class BasicMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.linear1(x)
-        x = torch.nn.functional.gelu(x, approximate='tanh')
+        x = torch.nn.functional.gelu(x, approximate="tanh")
         x = self.linear2(x)
         return x

@@ -148,7 +149,7 @@ def share_parameters_with_basic_te_model(te_model, basic_model):
     Parameter values are copied from pure PyTorch implementation.
 
     """
-    te_model.ln1.weight= basic_model.ln1.weight
+    te_model.ln1.weight = basic_model.ln1.weight
     te_model.ln1.bias = basic_model.ln1.bias
     te_model.qkv_projection.weight = basic_model.qkv_projection.weight
     te_model.qkv_projection.bias = basic_model.qkv_projection.bias

@@ -202,14 +203,15 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model):
     te_model.layernorm_mlp.fc2_bias = basic_model.mlp.linear2.bias
 
 
-def cast_to_representable(inp, scale = 1., fp8_format='e4m3'):
+def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"):
     import transformer_engine.pytorch.cpp_extensions as texcpp
     import transformer_engine_torch as tex
     from transformer_engine.pytorch.constants import TE_DType
-    fp8_type = tex.DType.kFloat8E4M3 if fp8_format == 'e4m3' else tex.DType.kFloat8E5M2
+
+    fp8_type = tex.DType.kFloat8E4M3 if fp8_format == "e4m3" else tex.DType.kFloat8E5M2
     input_type = TE_DType[inp.dtype]
     meta = tex.FP8TensorMeta()
-    meta.scale = torch.ones(1,dtype=torch.float32, device="cuda") * scale
+    meta.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
     meta.scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") / scale
     meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
     ret = texcpp.cast_to_fp8(inp, meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_type)
...
...@@ -15,11 +15,17 @@ from transformer_engine.pytorch.attention import RotaryPositionEmbedding ...@@ -15,11 +15,17 @@ from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init from transformer_engine.pytorch.fp8 import fp8_model_init
import transformers import transformers
from transformers.models.llama.modeling_llama import LlamaModel, LlamaForCausalLM, LlamaRMSNorm, LlamaConfig from transformers.models.llama.modeling_llama import (
LlamaModel,
LlamaForCausalLM,
LlamaRMSNorm,
LlamaConfig,
)
from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
from transformers.utils import WEIGHTS_INDEX_NAME from transformers.utils import WEIGHTS_INDEX_NAME
from transformers.utils.hub import get_checkpoint_shard_files from transformers.utils.hub import get_checkpoint_shard_files
@contextmanager @contextmanager
def replace_decoder(te_decoder_cls): def replace_decoder(te_decoder_cls):
""" """
...@@ -43,6 +49,7 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer): ...@@ -43,6 +49,7 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
args: positional args (for compatibility with `LlamaDecoderLayer`) args: positional args (for compatibility with `LlamaDecoderLayer`)
kwargs: keyword args (for compatibility with `LlamaDecoderLayer`) kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
""" """
def __init__(self, config, *args, **kwargs): def __init__(self, config, *args, **kwargs):
super().__init__( super().__init__(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
...@@ -56,22 +63,22 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer): ...@@ -56,22 +63,22 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
normalization="RMSNorm", normalization="RMSNorm",
activation="swiglu", activation="swiglu",
attn_input_format="bshd", attn_input_format="bshd",
num_gqa_groups=config.num_key_value_heads num_gqa_groups=config.num_key_value_heads,
) )
te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads) te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
def forward(self, def forward(self, hidden_states, *args, attention_mask, **kwargs):
hidden_states,
*args,
attention_mask,
**kwargs):
""" """
Custom forward to make sure we only pass relevant arguments to the Custom forward to make sure we only pass relevant arguments to the
forward pass of the `TransformerLayer`. Also, make sure the output forward pass of the `TransformerLayer`. Also, make sure the output
format matches the output of the HF's `LlamaDecoderLayer`. format matches the output of the HF's `LlamaDecoderLayer`.
""" """
return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),) return (
super().forward(
hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
),
)


class TELlamaForCausalLM:

@@ -95,21 +102,29 @@ class TELlamaForCausalLM:
        Custom method adapted from `from_pretrained` method in HuggingFace
        Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
        """
        vanilla_model = cls(config).to(kwargs["torch_dtype"])
        is_local = os.path.isdir(pretrained_model_name_or_path)
        subfolder = ""
        variant = None
        if os.path.isfile(
            os.path.join(
                pretrained_model_name_or_path,
                subfolder,
                _add_variant("model.safetensors.index.json", variant),
            )
        ):
            # Load from a sharded PyTorch checkpoint
            archive_file = os.path.join(
                pretrained_model_name_or_path,
                subfolder,
                _add_variant("model.safetensors.index.json", variant),
            )
            is_sharded = True
        elif os.path.isfile(
            os.path.join(
                pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
            )
        ):
            # Load from a sharded PyTorch checkpoint
            archive_file = os.path.join(
                pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
@@ -118,10 +133,9 @@ class TELlamaForCausalLM:
        else:
            raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")

        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
            pretrained_model_name_or_path,
            archive_file,
        )

        # If the checkpoint is not sharded, it's a trivial sharding case
@@ -142,48 +156,63 @@ class TELlamaForCausalLM:

        return vanilla_model


def replace_params(hf_state_dict, te_state_dict, config):
    # collect all layer prefixes to update
    all_layer_prefixes = set()
    for param_key in hf_state_dict.keys():
        layer_prefix_pat = "model.layers.\d+."
        m = re.match(layer_prefix_pat, param_key)
        if m is not None:
            all_layer_prefixes.add(m.group())

    for layer_prefix in all_layer_prefixes:
        # When loading weights into models with less number of layers, skip the
        # copy if the corresponding layer doesn't exist in HF model
        if layer_prefix + "input_layernorm.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"].data[
                :
            ] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]

        if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.query_weight"].data[:] = (
                hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
            )

        if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.key_weight"].data[:] = (
                hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
            )

        if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.value_weight"].data[:] = (
                hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
            )

        if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = hf_state_dict[
                layer_prefix + "self_attn.o_proj.weight"
            ].data[:]

        if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = hf_state_dict[
                layer_prefix + "post_attention_layernorm.weight"
            ].data[:]

        # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to
        # load them separately.
        if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
                : config.intermediate_size
            ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data

        if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
                config.intermediate_size :
            ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data

        if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = hf_state_dict[
                layer_prefix + "mlp.down_proj.weight"
            ].data[:]

    return all_layer_prefixes
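For orientation, here is a minimal sketch of how replace_params is typically driven over a sharded checkpoint. The actual shard-loading loop is elided from this diff, so the helper name _copy_weights_from_shards and the use of plain torch.load-able shards are assumptions for illustration only, not the repository's code.

import torch

def _copy_weights_from_shards(vanilla_model, resolved_archive_file, config):
    # Hypothetical driver: walk the HF checkpoint shards and copy each one
    # into the Transformer Engine model's parameter layout via replace_params.
    te_state_dict = vanilla_model.state_dict()
    for shard_file in resolved_archive_file:
        hf_state_dict = torch.load(shard_file, map_location="cpu")  # one HF shard at a time
        replace_params(hf_state_dict, te_state_dict, config)
    return vanilla_model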
@@ -10,29 +10,36 @@ import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    AutoConfig,
)
from transformers import DataCollatorForLanguageModeling
from datasets import load_dataset
from accelerate import Accelerator
from accelerate.utils.dataclasses import FP8RecipeKwargs


class HyperParameters:
    def __init__(self):
        self.mixed_precision = "bf16"
        # self.model_name = "" # <== Add model weight location here
        self.dataset_name = "timdettmers/openassistant-guanaco"
        self.dataset_text_field = "text"
        self.learning_rate = 1.41e-5
        self.batch_size = 8
        self.max_seq_length = 256
        self.gradient_accumulation_steps = 1
        self.num_warmup_steps = 5
        self.num_training_steps = 10


hyperparams = HyperParameters()


def get_dataloaders(accelerator: Accelerator, hyperparams):
    dataset = load_dataset(hyperparams.dataset_name, split="train")
    tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)
    if getattr(tokenizer, "pad_token", None) is None:
@@ -45,16 +52,12 @@ def get_dataloaders(accelerator:Accelerator, hyperparams):
            padding=False,
            max_length=hyperparams.max_seq_length,
            return_overflowing_tokens=False,
            return_length=False,
        )
        return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}

    with accelerator.main_process_first():
        dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

    # Simply pad to the multiple of 16 for both FP8 and BF16 precision
    pad_to_multiple_of = 16
@@ -72,6 +75,7 @@ def get_dataloaders(accelerator:Accelerator, hyperparams):
    train_dataloader = DataLoader(dataset, **dataloader_params)

    return train_dataloader
def init_baseline_model(hyperparams):
    # Init the model
    config = AutoConfig.from_pretrained(hyperparams.model_name)
@@ -84,42 +88,47 @@ def init_baseline_model(hyperparams):
    )
    model = model.cuda()
    # Needed for the cases when using TELlamaForCausalLM. So adding here for 1:1 comparison
    model.config.use_cache = False

    return model


def init_te_llama_model(hyperparams):
    # Init the model
    from te_llama import TELlamaForCausalLM

    config = AutoConfig.from_pretrained(hyperparams.model_name)
    config._attn_implementation = "flash_attention_2"
    model = TELlamaForCausalLM.from_pretrained_local(
        hyperparams.model_name,
        config=config,
        torch_dtype=torch.bfloat16,
    )
    model = model.cuda()
    # Needed for the cases when using TELlamaForCausalLM
    model.config.use_cache = False

    return model


def wrap_with_accelerator(model, hyperparams):
    # Create FP8 kwarg handler if required
    fp8_kwarg_handler = (
        [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None
    )

    # Init HF accelerator that's used for training
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=hyperparams.gradient_accumulation_steps,
        mixed_precision=hyperparams.mixed_precision,
        kwargs_handlers=fp8_kwarg_handler,
    )
    # accelerator.print(f'State: {accelerator.state}')
    train_dataloader = get_dataloaders(accelerator, hyperparams)

    # Wrap model, optimizer/scheduler, dataloaders in accelerate
    optimizer = AdamW(params=model.parameters(), lr=hyperparams.learning_rate, fused=True)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=100,
@@ -131,6 +140,7 @@ def wrap_with_accelerator(model, hyperparams):
    return accelerator, model, optimizer, train_dataloader, lr_scheduler


def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler):
    model.train()
    total_loss = 0
@@ -170,7 +180,11 @@ def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer,
    end.record()
    accelerator.end_training()

    print(
        f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step:"
        f" {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds"
    )
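Taken together, the helpers above form a small finetuning pipeline. A minimal usage sketch follows; the model path is a placeholder, and, as the commented-out model_name field in HyperParameters indicates, you must point it at a local Llama checkpoint yourself.

hyperparams.model_name = "/path/to/llama/weights"  # placeholder location, not a real path
hyperparams.mixed_precision = "fp8"  # or "bf16" for the baseline precision

model = init_te_llama_model(hyperparams)  # or init_baseline_model(hyperparams)
accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(
    model, hyperparams
)
finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)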


def restart_jupyter_notebook():
    # Try restarting the Jupyter kernel
@@ -179,18 +193,23 @@ def restart_jupyter_notebook():
    # Check whether the device memory has been flushed
    if torch.cuda.memory_allocated() != 0:
        import warnings

        warnings.warn("The device memory hasn't been flushed, trying with a second method!")

        # Try restarting the Jupyter kernel another way
        # Restart the kernel
        from IPython.core.display import HTML

        HTML("<script>Jupyter.notebook.kernel.restart()</script>")

        if torch.cuda.memory_allocated() != 0:
            print(
                "The device memory hasn't been flushed, try manually restarting the Jupyter kernel!"
            )


# Suppress the warnings
if not sys.warnoptions:
    import warnings

    warnings.simplefilter("ignore")
    torch.set_warn_always(False)
@@ -22,18 +22,19 @@ from jax.experimental.pjit import pjit
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
NAMED_BROADCAST_AXIS = "my_broadcast_axis"
NAMED_TP_AXIS = "my_tp_axis"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"


class Net(nn.Module):
    """NLP Encoder"""

    num_embed: int
    enable_seq_paral: bool

@@ -41,36 +42,43 @@ class Net(nn.Module):
    def __call__(self, x, mask, disable_dropout=False):
        x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
        te_Encoder = partial(
            te_flax.TransformerLayer,
            hidden_size=256,
            mlp_hidden_size=1024,
            num_attention_heads=8,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            dropout_rng_name=DROPOUT_KEY,
            layer_type=te_flax.TransformerLayerType.ENCODER,
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
            enable_sequence_parallel=self.enable_seq_paral,
            dtype=jnp.bfloat16,
        )
        x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
        x = x.reshape(x.shape[0], -1)

        if self.enable_seq_paral:
            # Trigger all-gather to collect a complete tensor along the sequence dimension on each device.
            x = jax.lax.with_sharding_constraint(
                x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
            )

        x = te_flax.DenseGeneral(
            features=256,
            kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
            bias_axes=(NAMED_TP_AXIS,),
            dtype=jnp.bfloat16,
        )(x)

        x = te_flax.DenseGeneral(
            features=256,
            kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
            bias_axes=(NAMED_BROADCAST_AXIS,),
            dtype=jnp.bfloat16,
        )(x)

        x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
        return x
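For orientation, here is a minimal sketch of how this Net module is instantiated and initialized, mirroring what train_and_evaluate does further down. The shapes and vocabulary size are illustrative only, and actually running the init requires Transformer Engine's JAX extensions on a GPU.

import jax
import jax.numpy as jnp

num_embed, batch, seq_len = 32, 4, 16  # illustrative sizes, not the example's real values
encoder = Net(num_embed, enable_seq_paral=False)

inputs = jnp.zeros((batch, seq_len), dtype=jnp.int32)
masks = jnp.zeros((batch, 1, seq_len, seq_len), dtype=jnp.uint8)
init_rngs = {PARAMS_KEY: jax.random.PRNGKey(0), DROPOUT_KEY: jax.random.PRNGKey(1)}

var_collect = encoder.init(init_rngs, inputs, masks)  # parameters (plus FP8 metadata when enabled)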
...@@ -98,20 +106,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): ...@@ -98,20 +106,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
"""Train for a single epoch.""" """Train for a single epoch."""
train_ds_size = len(train_ds['sentence']) train_ds_size = len(train_ds["sentence"])
steps_per_epoch = train_ds_size // batch_size steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size) perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size)) perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = [] epoch_loss = []
epoch_accuracy = [] epoch_accuracy = []
for perm in perms: for perm in perms:
batch_inputs = train_ds['sentence'][perm, ...] batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds['mask'][perm, ...] batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds['label'][perm, ...] batch_labels = train_ds["label"][perm, ...]
        state, loss, accuracy, var_collect = train_fn(
            state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
        )
epoch_loss.append(loss) epoch_loss.append(loss)
epoch_accuracy.append(accuracy) epoch_accuracy.append(accuracy)
...@@ -137,7 +146,7 @@ def eval_step(state, inputs, masks, labels, var_collect): ...@@ -137,7 +146,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def eval_model(state, test_ds, batch_size, var_collect, eval_fn): def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
"""Evaluation loop.""" """Evaluation loop."""
test_ds_size = len(test_ds['sentence']) test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size valid_size = num_steps * batch_size
all_loss = [] all_loss = []
...@@ -145,9 +154,9 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): ...@@ -145,9 +154,9 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
for batch_start in range(0, valid_size, batch_size): for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size batch_end = batch_start + batch_size
batch_inputs = test_ds['sentence'][batch_start:batch_end] batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds['mask'][batch_start:batch_end] batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds['label'][batch_start:batch_end] batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect) loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss) all_loss.append(loss)
all_accuracy.append(accuracy) all_accuracy.append(accuracy)
...@@ -159,12 +168,12 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): ...@@ -159,12 +168,12 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def data_preprocess(dataset, vocab, word_id, max_seq_len): def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers.""" """Convert tokens to numbers."""
nltk.download('punkt') nltk.download("punkt")
dataset_size = len(dataset['sentence']) dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']): for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence) tokens = nltk.word_tokenize(sentence)
tensor = output[j] tensor = output[j]
...@@ -184,9 +193,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -184,9 +193,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
mask_2d[:seq_len, :seq_len] = 0 mask_2d[:seq_len, :seq_len] = 0
new_dataset = { new_dataset = {
'sentence': output, "sentence": output,
'label': dataset['label'].astype(np.float32), "label": dataset["label"].astype(np.float32),
'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
} }
return new_dataset, vocab, word_id return new_dataset, vocab, word_id
...@@ -196,12 +205,12 @@ def get_datasets(max_seq_len): ...@@ -196,12 +205,12 @@ def get_datasets(max_seq_len):
vocab = {} vocab = {}
word_id = 0 word_id = 0
train_ds = load_dataset('glue', 'cola', split='train') train_ds = load_dataset("glue", "cola", split="train")
train_ds.set_format(type='np') train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset('glue', 'cola', split='validation') test_ds = load_dataset("glue", "cola", split="validation")
test_ds.set_format(type='np') test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id return train_ds, test_ds, word_id
...@@ -210,7 +219,8 @@ def check_fp8(state, var_collect, inputs, masks, labels): ...@@ -210,7 +219,8 @@ def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8." "Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str( assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
def get_params_pspec(sharding_rules, abs_var_collect): def get_params_pspec(sharding_rules, abs_var_collect):
...@@ -255,8 +265,9 @@ def train_and_evaluate(args): ...@@ -255,8 +265,9 @@ def train_and_evaluate(args):
num_gpu_tp = 1 num_gpu_tp = 1
assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}" assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
    assert (
        args.test_batch_size % num_gpu_dp == 0
    ), f"Test batch size needs to be multiple of {num_gpu_dp}"
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)): with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
...@@ -270,9 +281,9 @@ def train_and_evaluate(args): ...@@ -270,9 +281,9 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size] label_shape = [args.batch_size]
        with te.fp8_autocast(
            args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
        ):
encoder = Net(num_embed, args.enable_sp) encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32) inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8) masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
...@@ -285,18 +296,21 @@ def train_and_evaluate(args): ...@@ -285,18 +296,21 @@ def train_and_evaluate(args):
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec) in_shardings = (None, inputs_pspec, masks_pspec)
            out_shardings = {
                key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect
            }
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks) var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr) optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
            state = train_state.TrainState.create(
                apply_fn=encoder.apply, params=params, tx=optimizer
            )
            state_pspec = get_state_pspec(state, params_pspec)
            labels_pspec = jax.sharding.PartitionSpec(
                DEVICE_DP_AXIS,
            )
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None) out_shardings = (state_pspec, None, None, None)
...@@ -323,16 +337,20 @@ def train_and_evaluate(args): ...@@ -323,16 +337,20 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
                state, train_loss, train_accuracy, var_collect = train_epoch(
                    state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step
                )

                test_loss, test_accuracy = eval_model(
                    state, test_ds, args.test_batch_size, var_collect, pjit_eval_step
                )

                print(
                    f"Epoch: {epoch:>2} "
                    f"Train Loss: {train_loss:.6f} "
                    f"Train Accuracy: {train_accuracy:.6f} "
                    f"Test Loss: {test_loss:.6f} "
                    f"Test Accuracy: {test_accuracy:.6f} "
                )
return [train_loss, train_accuracy, test_loss, test_accuracy] return [train_loss, train_accuracy, test_loss, test_accuracy]
...@@ -382,14 +400,15 @@ def encoder_parser(args): ...@@ -382,14 +400,15 @@ def encoder_parser(args):
help="quickly check a single pass", help="quickly check a single pass",
) )
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)") parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument("--use-fp8", parser.add_argument(
action="store_true", "--use-fp8",
default=False, action="store_true",
help="Use FP8 for inference and training without recalibration") default=False,
parser.add_argument("--enable-sp", help="Use FP8 for inference and training without recalibration",
action="store_true", )
default=False, parser.add_argument(
help="Enable sequence parallelism.") "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
)
return parser.parse_args(args) return parser.parse_args(args)
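The same example can be driven programmatically through the two entry points defined above. The flag spellings follow encoder_parser, and running it still assumes a machine with enough GPUs for the data/tensor-parallel mesh.

args = encoder_parser(["--use-fp8", "--enable-sp", "--seed", "1"])
train_loss, train_acc, test_loss, test_acc = train_and_evaluate(args)
print(test_acc)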
@@ -22,32 +22,35 @@ from jax.experimental.pjit import pjit
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
class Net(nn.Module): class Net(nn.Module):
"""NLP Encoder""" """NLP Encoder"""
num_embed: int num_embed: int
@nn.compact @nn.compact
def __call__(self, x, mask, disable_dropout=False): def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
        te_Encoder = partial(
            te_flax.TransformerLayer,
            hidden_size=256,
            mlp_hidden_size=1024,
            num_attention_heads=8,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            dropout_rng_name=DROPOUT_KEY,
            layer_type=te_flax.TransformerLayerType.ENCODER,
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
            dtype=jnp.bfloat16,
        )
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1) x = x.reshape(x.shape[0], -1)
...@@ -82,20 +85,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): ...@@ -82,20 +85,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
"""Train for a single epoch.""" """Train for a single epoch."""
train_ds_size = len(train_ds['sentence']) train_ds_size = len(train_ds["sentence"])
steps_per_epoch = train_ds_size // batch_size steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size) perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size)) perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = [] epoch_loss = []
epoch_accuracy = [] epoch_accuracy = []
for perm in perms: for perm in perms:
batch_inputs = train_ds['sentence'][perm, ...] batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds['mask'][perm, ...] batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds['label'][perm, ...] batch_labels = train_ds["label"][perm, ...]
        state, loss, accuracy, var_collect = train_fn(
            state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
        )
epoch_loss.append(loss) epoch_loss.append(loss)
epoch_accuracy.append(accuracy) epoch_accuracy.append(accuracy)
...@@ -121,7 +125,7 @@ def eval_step(state, inputs, masks, labels, var_collect): ...@@ -121,7 +125,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def eval_model(state, test_ds, batch_size, var_collect, eval_fn): def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
"""Evaluation loop.""" """Evaluation loop."""
test_ds_size = len(test_ds['sentence']) test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size valid_size = num_steps * batch_size
all_loss = [] all_loss = []
...@@ -129,9 +133,9 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): ...@@ -129,9 +133,9 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
for batch_start in range(0, valid_size, batch_size): for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size batch_end = batch_start + batch_size
batch_inputs = test_ds['sentence'][batch_start:batch_end] batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds['mask'][batch_start:batch_end] batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds['label'][batch_start:batch_end] batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect) loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss) all_loss.append(loss)
all_accuracy.append(accuracy) all_accuracy.append(accuracy)
...@@ -143,12 +147,12 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): ...@@ -143,12 +147,12 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def data_preprocess(dataset, vocab, word_id, max_seq_len): def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers.""" """Convert tokens to numbers."""
nltk.download('punkt') nltk.download("punkt")
dataset_size = len(dataset['sentence']) dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']): for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence) tokens = nltk.word_tokenize(sentence)
tensor = output[j] tensor = output[j]
...@@ -168,9 +172,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -168,9 +172,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
mask_2d[:seq_len, :seq_len] = 0 mask_2d[:seq_len, :seq_len] = 0
new_dataset = { new_dataset = {
'sentence': output, "sentence": output,
'label': dataset['label'].astype(np.float32), "label": dataset["label"].astype(np.float32),
'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
} }
return new_dataset, vocab, word_id return new_dataset, vocab, word_id
...@@ -180,12 +184,12 @@ def get_datasets(max_seq_len): ...@@ -180,12 +184,12 @@ def get_datasets(max_seq_len):
vocab = {} vocab = {}
word_id = 0 word_id = 0
train_ds = load_dataset('glue', 'cola', split='train') train_ds = load_dataset("glue", "cola", split="train")
train_ds.set_format(type='np') train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset('glue', 'cola', split='validation') test_ds = load_dataset("glue", "cola", split="validation")
test_ds.set_format(type='np') test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id return train_ds, test_ds, word_id
...@@ -194,7 +198,8 @@ def check_fp8(state, var_collect, inputs, masks, labels): ...@@ -194,7 +198,8 @@ def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8." "Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str( assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
def get_params_pspec(sharding_rules, abs_var_collect): def get_params_pspec(sharding_rules, abs_var_collect):
...@@ -232,8 +237,7 @@ def train_and_evaluate(args): ...@@ -232,8 +237,7 @@ def train_and_evaluate(args):
num_gpu = jax.local_device_count() num_gpu = jax.local_device_count()
assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}" assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}"
    assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}"
device_mesh = mesh_utils.create_device_mesh((num_gpu,)) device_mesh = mesh_utils.create_device_mesh((num_gpu,))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)): with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)):
...@@ -247,8 +251,9 @@ def train_and_evaluate(args): ...@@ -247,8 +251,9 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size] label_shape = [args.batch_size]
        with te.fp8_autocast(
            args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)
        ):
encoder = Net(num_embed) encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32) inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8) masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
...@@ -260,18 +265,21 @@ def train_and_evaluate(args): ...@@ -260,18 +265,21 @@ def train_and_evaluate(args):
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec) in_shardings = (None, inputs_pspec, masks_pspec)
            out_shardings = {
                key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect
            }
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks) var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr) optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
            state = train_state.TrainState.create(
                apply_fn=encoder.apply, params=params, tx=optimizer
            )
            state_pspec = get_state_pspec(state, params_pspec)
            labels_pspec = jax.sharding.PartitionSpec(
                DEVICE_DP_AXIS,
            )
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None) out_shardings = (state_pspec, None, None, None)
...@@ -298,16 +306,20 @@ def train_and_evaluate(args): ...@@ -298,16 +306,20 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
                state, train_loss, train_accuracy, var_collect = train_epoch(
                    state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step
                )

                test_loss, test_accuracy = eval_model(
                    state, test_ds, args.test_batch_size, var_collect, pjit_eval_step
                )

                print(
                    f"Epoch: {epoch:>2} "
                    f"Train Loss: {train_loss:.6f} "
                    f"Train Accuracy: {train_accuracy:.6f} "
                    f"Test Loss: {test_loss:.6f} "
                    f"Test Accuracy: {test_accuracy:.6f} "
                )
return [train_loss, train_accuracy, test_loss, test_accuracy] return [train_loss, train_accuracy, test_loss, test_accuracy]
...@@ -357,10 +369,12 @@ def encoder_parser(args): ...@@ -357,10 +369,12 @@ def encoder_parser(args):
help="quickly check a single pass", help="quickly check a single pass",
) )
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)") parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument("--use-fp8", parser.add_argument(
action="store_true", "--use-fp8",
default=False, action="store_true",
help="Use FP8 for inference and training without recalibration") default=False,
help="Use FP8 for inference and training without recalibration",
)
return parser.parse_args(args) return parser.parse_args(args)
...@@ -19,30 +19,33 @@ from flax.training import train_state ...@@ -19,30 +19,33 @@ from flax.training import train_state
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax import transformer_engine.jax.flax as te_flax
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
class Net(nn.Module): class Net(nn.Module):
"""NLP Encoder""" """NLP Encoder"""
num_embed: int num_embed: int
@nn.compact @nn.compact
def __call__(self, x, mask, disable_dropout=False): def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
        te_Encoder = partial(
            te_flax.TransformerLayer,
            hidden_size=256,
            mlp_hidden_size=1024,
            num_attention_heads=8,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            dropout_rng_name=DROPOUT_KEY,
            layer_type=te_flax.TransformerLayerType.ENCODER,
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
            dtype=jnp.bfloat16,
        )
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1) x = x.reshape(x.shape[0], -1)
...@@ -78,20 +81,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): ...@@ -78,20 +81,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def train_epoch(state, train_ds, batch_size, rngs, var_collect): def train_epoch(state, train_ds, batch_size, rngs, var_collect):
"""Train for a single epoch.""" """Train for a single epoch."""
train_ds_size = len(train_ds['sentence']) train_ds_size = len(train_ds["sentence"])
steps_per_epoch = train_ds_size // batch_size steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size) perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size)) perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = [] epoch_loss = []
epoch_accuracy = [] epoch_accuracy = []
for perm in perms: for perm in perms:
batch_inputs = train_ds['sentence'][perm, ...] batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds['mask'][perm, ...] batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds['label'][perm, ...] batch_labels = train_ds["label"][perm, ...]
        state, loss, accuracy, var_collect = train_step(
            state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
        )
epoch_loss.append(loss) epoch_loss.append(loss)
epoch_accuracy.append(accuracy) epoch_accuracy.append(accuracy)
...@@ -118,7 +122,7 @@ def eval_step(state, inputs, masks, labels, var_collect): ...@@ -118,7 +122,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def eval_model(state, test_ds, batch_size, var_collect): def eval_model(state, test_ds, batch_size, var_collect):
"""Evaluation loop.""" """Evaluation loop."""
test_ds_size = len(test_ds['sentence']) test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size valid_size = num_steps * batch_size
all_loss = [] all_loss = []
...@@ -126,9 +130,9 @@ def eval_model(state, test_ds, batch_size, var_collect): ...@@ -126,9 +130,9 @@ def eval_model(state, test_ds, batch_size, var_collect):
for batch_start in range(0, valid_size, batch_size): for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size batch_end = batch_start + batch_size
batch_inputs = test_ds['sentence'][batch_start:batch_end] batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds['mask'][batch_start:batch_end] batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds['label'][batch_start:batch_end] batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect) loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss) all_loss.append(loss)
all_accuracy.append(accuracy) all_accuracy.append(accuracy)
...@@ -140,12 +144,12 @@ def eval_model(state, test_ds, batch_size, var_collect): ...@@ -140,12 +144,12 @@ def eval_model(state, test_ds, batch_size, var_collect):
def data_preprocess(dataset, vocab, word_id, max_seq_len): def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers.""" """Convert tokens to numbers."""
nltk.download('punkt') nltk.download("punkt")
dataset_size = len(dataset['sentence']) dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']): for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence) tokens = nltk.word_tokenize(sentence)
tensor = output[j] tensor = output[j]
...@@ -165,9 +169,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -165,9 +169,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
mask_2d[:seq_len, :seq_len] = 0 mask_2d[:seq_len, :seq_len] = 0
new_dataset = { new_dataset = {
'sentence': output, "sentence": output,
'label': dataset['label'].astype(np.float32), "label": dataset["label"].astype(np.float32),
'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
} }
return new_dataset, vocab, word_id return new_dataset, vocab, word_id
...@@ -177,12 +181,12 @@ def get_datasets(max_seq_len): ...@@ -177,12 +181,12 @@ def get_datasets(max_seq_len):
vocab = {} vocab = {}
word_id = 0 word_id = 0
train_ds = load_dataset('glue', 'cola', split='train') train_ds = load_dataset("glue", "cola", split="train")
train_ds.set_format(type='np') train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset('glue', 'cola', split='validation') test_ds = load_dataset("glue", "cola", split="validation")
test_ds.set_format(type='np') test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id return train_ds, test_ds, word_id
...@@ -191,7 +195,8 @@ def check_fp8(state, var_collect, inputs, masks, labels): ...@@ -191,7 +195,8 @@ def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8." "Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str( assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
def train_and_evaluate(args): def train_and_evaluate(args):
...@@ -214,9 +219,9 @@ def train_and_evaluate(args): ...@@ -214,9 +219,9 @@ def train_and_evaluate(args):
masks = jnp.zeros(mask_shape, dtype=jnp.uint8) masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
var_collect = encoder.init(init_rngs, inputs, masks) var_collect = encoder.init(init_rngs, inputs, masks)
tx = optax.adamw(args.lr) tx = optax.adamw(args.lr)
    state = train_state.TrainState.create(
        apply_fn=encoder.apply, params=var_collect[PARAMS_KEY], tx=tx
    )
if args.use_fp8: if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
...@@ -235,15 +240,18 @@ def train_and_evaluate(args): ...@@ -235,15 +240,18 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
        state, train_loss, train_accuracy, var_collect = train_epoch(
            state, train_ds, args.batch_size, rngs, var_collect
        )

        test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)

        print(
            f"Epoch: {epoch:>2} "
            f"Train Loss: {train_loss:.6f} "
            f"Train Accuracy: {train_accuracy:.6f} "
            f"Test Loss: {test_loss:.6f} "
            f"Test Accuracy: {test_accuracy:.6f} "
        )
return [train_loss, train_accuracy, test_loss, test_accuracy] return [train_loss, train_accuracy, test_loss, test_accuracy]
...@@ -293,10 +301,12 @@ def encoder_parser(args): ...@@ -293,10 +301,12 @@ def encoder_parser(args):
help="quickly check a single pass", help="quickly check a single pass",
) )
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)") parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument("--use-fp8", parser.add_argument(
action="store_true", "--use-fp8",
default=False, action="store_true",
help="Use FP8 for inference and training without recalibration") default=False,
help="Use FP8 for inference and training without recalibration",
)
return parser.parse_args(args) return parser.parse_args(args)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""MNIST training on single GPU"""
import argparse import argparse
import unittest import unittest
from functools import partial from functools import partial
...@@ -20,13 +20,14 @@ import transformer_engine.jax.flax as te_flax ...@@ -20,13 +20,14 @@ import transformer_engine.jax.flax as te_flax
IMAGE_H = 28 IMAGE_H = 28
IMAGE_W = 28 IMAGE_W = 28
IMAGE_C = 1 IMAGE_C = 1
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"


class Net(nn.Module):
    """CNN model for MNIST."""

    use_te: bool = False

    @nn.compact
...@@ -83,17 +84,17 @@ def update_model(state, grads): ...@@ -83,17 +84,17 @@ def update_model(state, grads):
def train_epoch(state, train_ds, batch_size, rngs, var_collect): def train_epoch(state, train_ds, batch_size, rngs, var_collect):
"""Train for a single epoch.""" """Train for a single epoch."""
train_ds_size = len(train_ds['image']) train_ds_size = len(train_ds["image"])
steps_per_epoch = train_ds_size // batch_size steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size) perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size)) perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = [] epoch_loss = []
epoch_accuracy = [] epoch_accuracy = []
for perm in perms: for perm in perms:
batch_images = train_ds['image'][perm, ...] batch_images = train_ds["image"][perm, ...]
batch_labels = train_ds['label'][perm, ...] batch_labels = train_ds["label"][perm, ...]
grads, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect, rngs) grads, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect, rngs)
state, var_collect = update_model(state, grads) state, var_collect = update_model(state, grads)
epoch_loss.append(loss) epoch_loss.append(loss)
...@@ -106,7 +107,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect): ...@@ -106,7 +107,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect):
def eval_model(state, test_ds, batch_size, var_collect): def eval_model(state, test_ds, batch_size, var_collect):
"""Evaluation loop.""" """Evaluation loop."""
test_ds_size = len(test_ds['image']) test_ds_size = len(test_ds["image"])
num_steps = test_ds_size // batch_size num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size valid_size = num_steps * batch_size
all_loss = [] all_loss = []
...@@ -114,8 +115,8 @@ def eval_model(state, test_ds, batch_size, var_collect): ...@@ -114,8 +115,8 @@ def eval_model(state, test_ds, batch_size, var_collect):
for batch_start in range(0, valid_size, batch_size): for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size batch_end = batch_start + batch_size
batch_images = test_ds['image'][batch_start:batch_end] batch_images = test_ds["image"][batch_start:batch_end]
batch_labels = test_ds['label'][batch_start:batch_end] batch_labels = test_ds["label"][batch_start:batch_end]
_, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect) _, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect)
all_loss.append(loss) all_loss.append(loss)
all_accuracy.append(accuracy) all_accuracy.append(accuracy)
...@@ -127,21 +128,21 @@ def eval_model(state, test_ds, batch_size, var_collect): ...@@ -127,21 +128,21 @@ def eval_model(state, test_ds, batch_size, var_collect):
def get_datasets(): def get_datasets():
"""Load MNIST train and test datasets into memory.""" """Load MNIST train and test datasets into memory."""
train_ds = load_dataset('mnist', split='train') train_ds = load_dataset("mnist", split="train")
train_ds.set_format(type='np') train_ds.set_format(type="np")
batch_size = train_ds['image'].shape[0] batch_size = train_ds["image"].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C) shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
new_train_ds = { new_train_ds = {
'image': train_ds['image'].astype(np.float32).reshape(shape) / 255., "image": train_ds["image"].astype(np.float32).reshape(shape) / 255.0,
'label': train_ds['label'] "label": train_ds["label"],
} }
test_ds = load_dataset('mnist', split='test') test_ds = load_dataset("mnist", split="test")
test_ds.set_format(type='np') test_ds.set_format(type="np")
batch_size = test_ds['image'].shape[0] batch_size = test_ds["image"].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C) shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
new_test_ds = { new_test_ds = {
'image': test_ds['image'].astype(np.float32).reshape(shape) / 255., "image": test_ds["image"].astype(np.float32).reshape(shape) / 255.0,
'label': test_ds['label'] "label": test_ds["label"],
} }
return new_train_ds, new_test_ds return new_train_ds, new_test_ds
...@@ -149,8 +150,13 @@ def get_datasets(): ...@@ -149,8 +150,13 @@ def get_datasets():
def check_fp8(state, var_collect, input_shape, label_shape): def check_fp8(state, var_collect, input_shape, label_shape):
"Check if model includes FP8." "Check if model includes FP8."
assert "f8_" in str( assert "f8_" in str(
jax.make_jaxpr(apply_model)(state, jnp.empty(input_shape, dtype=jnp.bfloat16), jax.make_jaxpr(apply_model)(
jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect)) state,
jnp.empty(input_shape, dtype=jnp.bfloat16),
jnp.empty(label_shape, dtype=jnp.bfloat16),
var_collect,
)
)
def train_and_evaluate(args): def train_and_evaluate(args):
...@@ -173,17 +179,21 @@ def train_and_evaluate(args): ...@@ -173,17 +179,21 @@ def train_and_evaluate(args):
cnn = Net(args.use_te) cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16)) var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
tx = optax.sgd(args.lr, args.momentum) tx = optax.sgd(args.lr, args.momentum)
state = train_state.TrainState.create(apply_fn=cnn.apply, state = train_state.TrainState.create(
params=var_collect[PARAMS_KEY], apply_fn=cnn.apply, params=var_collect[PARAMS_KEY], tx=tx
tx=tx) )
if args.use_fp8: if args.use_fp8:
check_fp8(state, var_collect, input_shape, label_shape) check_fp8(state, var_collect, input_shape, label_shape)
if args.dry_run: if args.dry_run:
apply_model(state, jnp.empty(input_shape, dtype=jnp.bfloat16), apply_model(
jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect, state,
{DROPOUT_KEY: dropout_rng}) jnp.empty(input_shape, dtype=jnp.bfloat16),
jnp.empty(label_shape, dtype=jnp.bfloat16),
var_collect,
{DROPOUT_KEY: dropout_rng},
)
print("PASSED") print("PASSED")
return None return None
...@@ -193,14 +203,17 @@ def train_and_evaluate(args): ...@@ -193,14 +203,17 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch( state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect) state, train_ds, args.batch_size, rngs, var_collect
)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect) test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
print(f"Epoch: {epoch:>2} " print(
f"Train Loss: {train_loss:.6f} " f"Epoch: {epoch:>2} "
f"Train Accuracy: {train_accuracy:.6f} " f"Train Loss: {train_loss:.6f} "
f"Test Loss: {test_loss:.6f} " f"Train Accuracy: {train_accuracy:.6f} "
f"Test Accuracy: {test_accuracy:.6f} ") f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} "
)
return [train_loss, train_accuracy, test_loss, test_accuracy] return [train_loss, train_accuracy, test_loss, test_accuracy]
...@@ -250,15 +263,18 @@ def mnist_parser(args): ...@@ -250,15 +263,18 @@ def mnist_parser(args):
help="quickly check a single pass", help="quickly check a single pass",
) )
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--use-fp8", parser.add_argument(
action="store_true", "--use-fp8",
default=False, action="store_true",
help="Use FP8 for inference and training without recalibration. " \ default=False,
"It also enables Transformer Engine implicitly.") help=(
parser.add_argument("--use-te", "Use FP8 for inference and training without recalibration. "
action="store_true", "It also enables Transformer Engine implicitly."
default=False, ),
help="Use Transformer Engine") )
parser.add_argument(
"--use-te", action="store_true", default=False, help="Use Transformer Engine"
)
return parser.parse_args(args) return parser.parse_args(args)
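For reference, the parser above can also be driven programmatically; a hypothetical invocation that enables Transformer Engine with FP8, using only the flags visible in this diff, would look like:

# Hypothetical programmatic invocation of the JAX MNIST example.
args = mnist_parser(["--use-te", "--use-fp8"])
train_and_evaluate(args)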
......
...@@ -59,7 +59,9 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8): ...@@ -59,7 +59,9 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8):
model.train() model.train()
losses = [] losses = []
for batch_id, (data, labels) in enumerate(train_loader): for batch_id, (data, labels) in enumerate(train_loader):
with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager with paddle.amp.auto_cast(
dtype="bfloat16", level="O2"
): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=use_fp8): with te.fp8_autocast(enabled=use_fp8):
outputs = model(data) outputs = model(data)
loss = F.cross_entropy(outputs, labels) loss = F.cross_entropy(outputs, labels)
...@@ -70,10 +72,12 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8): ...@@ -70,10 +72,12 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8):
optimizer.clear_gradients() optimizer.clear_gradients()
if batch_id % args.log_interval == 0: if batch_id % args.log_interval == 0:
print(f"Train Epoch: {epoch} " print(
f"[{batch_id * len(data)}/{len(train_loader.dataset)} " f"Train Epoch: {epoch} "
f"({100. * batch_id / len(train_loader):.0f}%)]\t" f"[{batch_id * len(data)}/{len(train_loader.dataset)} "
f"Loss: {loss.item():.6f}") f"({100. * batch_id / len(train_loader):.0f}%)]\t"
f"Loss: {loss.item():.6f}"
)
if args.dry_run: if args.dry_run:
return loss.item() return loss.item()
avg_loss = sum(losses) / len(losses) avg_loss = sum(losses) / len(losses)
...@@ -89,7 +93,9 @@ def evaluate(model, test_loader, epoch, use_fp8): ...@@ -89,7 +93,9 @@ def evaluate(model, test_loader, epoch, use_fp8):
with paddle.no_grad(): with paddle.no_grad():
for data, labels in test_loader: for data, labels in test_loader:
with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager with paddle.amp.auto_cast(
dtype="bfloat16", level="O2"
): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=use_fp8): with te.fp8_autocast(enabled=use_fp8):
outputs = model(data) outputs = model(data)
acc = metric.compute(outputs, labels) acc = metric.compute(outputs, labels)
...@@ -104,7 +110,9 @@ def calibrate(model, test_loader): ...@@ -104,7 +110,9 @@ def calibrate(model, test_loader):
with paddle.no_grad(): with paddle.no_grad():
for data, _ in test_loader: for data, _ in test_loader:
with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager with paddle.amp.auto_cast(
dtype="bfloat16", level="O2"
): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=False, calibrating=True): with te.fp8_autocast(enabled=False, calibrating=True):
_ = model(data) _ = model(data)
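The calibrate pass above is what makes FP8 inference possible without FP8 training: BF16 forward passes inside te.fp8_autocast(enabled=False, calibrating=True) let Transformer Engine layers record amax statistics, so a later FP8 evaluation has usable scaling factors once the checkpoint is saved and reloaded. A condensed, hedged sketch of that flow (the loader name is a placeholder):

model.eval()
with paddle.no_grad():
    for data, _ in calib_loader:  # hypothetical calibration DataLoader
        with paddle.amp.auto_cast(dtype="bfloat16", level="O2"):
            with te.fp8_autocast(enabled=False, calibrating=True):
                _ = model(data)  # forward only; amax statistics are collected
paddle.save(model.state_dict(), "mnist_cnn.pdparams")
# Later: reload the checkpoint and run evaluate(...) with FP8 enabled, as below.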
...@@ -160,20 +168,27 @@ def mnist_parser(args): ...@@ -160,20 +168,27 @@ def mnist_parser(args):
metavar="N", metavar="N",
help="how many batches to wait before logging training status", help="how many batches to wait before logging training status",
) )
parser.add_argument("--use-fp8", parser.add_argument(
action="store_true", "--use-fp8",
default=False, action="store_true",
help="Use FP8 for inference and training without recalibration. " \ default=False,
"It also enables Transformer Engine implicitly.") help=(
parser.add_argument("--use-fp8-infer", "Use FP8 for inference and training without recalibration. "
action="store_true", "It also enables Transformer Engine implicitly."
default=False, ),
help="Use FP8 for inference only. If not using FP8 for training, " )
"calibration is performed for FP8 infernece.") parser.add_argument(
parser.add_argument("--use-te", "--use-fp8-infer",
action="store_true", action="store_true",
default=False, default=False,
help="Use Transformer Engine") help=(
"Use FP8 for inference only. If not using FP8 for training, "
"calibration is performed for FP8 infernece."
),
)
parser.add_argument(
"--use-te", action="store_true", default=False, help="Use Transformer Engine"
)
args = parser.parse_args(args) args = parser.parse_args(args)
return args return args
...@@ -185,9 +200,9 @@ def train_and_evaluate(args): ...@@ -185,9 +200,9 @@ def train_and_evaluate(args):
paddle.seed(args.seed) paddle.seed(args.seed)
# Load MNIST dataset # Load MNIST dataset
transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW') transform = Normalize(mean=[127.5], std=[127.5], data_format="CHW")
train_dataset = MNIST(mode='train', transform=transform) train_dataset = MNIST(mode="train", transform=transform)
val_dataset = MNIST(mode='test', transform=transform) val_dataset = MNIST(mode="test", transform=transform)
# Define data loaders # Define data loaders
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
...@@ -198,7 +213,7 @@ def train_and_evaluate(args): ...@@ -198,7 +213,7 @@ def train_and_evaluate(args):
optimizer = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters()) optimizer = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
# Cast model to BF16 # Cast model to BF16
model = paddle.amp.decorate(models=model, level='O2', dtype='bfloat16') model = paddle.amp.decorate(models=model, level="O2", dtype="bfloat16")
for epoch in range(1, args.epochs + 1): for epoch in range(1, args.epochs + 1):
loss = train(args, model, train_loader, optimizer, epoch, args.use_fp8) loss = train(args, model, train_loader, optimizer, epoch, args.use_fp8)
...@@ -209,7 +224,7 @@ def train_and_evaluate(args): ...@@ -209,7 +224,7 @@ def train_and_evaluate(args):
if args.save_model or args.use_fp8_infer: if args.save_model or args.use_fp8_infer:
paddle.save(model.state_dict(), "mnist_cnn.pdparams") paddle.save(model.state_dict(), "mnist_cnn.pdparams")
print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8)) print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8))
weights = paddle.load("mnist_cnn.pdparams") weights = paddle.load("mnist_cnn.pdparams")
model.set_state_dict(weights) model.set_state_dict(weights)
acc = evaluate(model, val_loader, 0, args.use_fp8) acc = evaluate(model, val_loader, 0, args.use_fp8)
...@@ -235,8 +250,10 @@ class TestMNIST(unittest.TestCase): ...@@ -235,8 +250,10 @@ class TestMNIST(unittest.TestCase):
assert actual[0] < desired_traing_loss assert actual[0] < desired_traing_loss
assert actual[1] > desired_test_accuracy assert actual[1] > desired_test_accuracy
@unittest.skipIf(paddle.device.cuda.get_device_capability() < (8, 0), @unittest.skipIf(
"BF16 MNIST example requires Ampere+ GPU") paddle.device.cuda.get_device_capability() < (8, 0),
"BF16 MNIST example requires Ampere+ GPU",
)
def test_te_bf16(self): def test_te_bf16(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
self.args.use_te = True self.args.use_te = True
......
...@@ -15,62 +15,77 @@ import torch.distributed as dist ...@@ -15,62 +15,77 @@ import torch.distributed as dist
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling from transformer_engine.common.recipe import Format, DelayedScaling
def parse_args(argv=None, namespace=None): def parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers.") description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers."
parser.add_argument('-i', "--num-iters", type=int, default=5, )
help="Number of dummy 'training' iterations.") parser.add_argument(
parser.add_argument('-b', "--batch-size", type=int, default=2, "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
help="Input batch size.") )
parser.add_argument('-s', "--seq-length", type=int, default=2048, parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
help="Input sequence length.") parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.")
parser.add_argument('-n', "--num-heads", type=int, default=64, parser.add_argument(
help="Number of attention heads.") "-n", "--num-heads", type=int, default=64, help="Number of attention heads."
parser.add_argument('-d', "--head-dim", type=int, default=128, )
help="Dimension of each attention head.") parser.add_argument(
parser.add_argument("--mlp-expansion-factor", type=int, default=4, "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head."
help="MLP block intermediate size as a factor of hidden dimension.") )
parser.add_argument("--seed", type=int, default=1234, parser.add_argument(
help="RNG seed.") "--mlp-expansion-factor",
parser.add_argument("--fp8", action="store_true", default=False, type=int,
help="Enables the te.fp8_autocast() context.") default=4,
parser.add_argument("--no-comm-overlap", action="store_true", default=False, help="MLP block intermediate size as a factor of hidden dimension.",
help="Disable the comm+GEMM overlap.") )
parser.add_argument('-v', "--verbose", action="store_true", default=False) parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
)
parser.add_argument(
"--no-comm-overlap",
action="store_true",
default=False,
help="Disable the comm+GEMM overlap.",
)
parser.add_argument("-v", "--verbose", action="store_true", default=False)
return parser.parse_args(argv, namespace) return parser.parse_args(argv, namespace)
def train(opts): def train(opts):
WORLD_RANK = int(os.getenv("RANK")) WORLD_RANK = int(os.getenv("RANK"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE")) WORLD_SIZE = int(os.getenv("WORLD_SIZE"))
def dist_print(msg, end='\n', all_ranks=False):
def dist_print(msg, end="\n", all_ranks=False):
if WORLD_RANK == 0 or all_ranks: if WORLD_RANK == 0 or all_ranks:
print(f"[RANK-{WORLD_RANK}] {msg}", end=end) print(f"[RANK-{WORLD_RANK}] {msg}", end=end)
# Seed RNG # Seed RNG
torch.cuda.set_device(WORLD_RANK) torch.cuda.set_device(WORLD_RANK)
torch.manual_seed(opts.seed+WORLD_RANK) torch.manual_seed(opts.seed + WORLD_RANK)
torch.cuda.manual_seed(opts.seed+WORLD_RANK) torch.cuda.manual_seed(opts.seed + WORLD_RANK)
# Initialize torch.distributed global process group and get TP group # Initialize torch.distributed global process group and get TP group
dist.init_process_group(backend="nccl", dist.init_process_group(
rank=WORLD_RANK, backend="nccl",
world_size=WORLD_SIZE, rank=WORLD_RANK,
device_id=torch.device(f'cuda:{WORLD_RANK}')) world_size=WORLD_SIZE,
device_id=torch.device(f"cuda:{WORLD_RANK}"),
)
tp_group = dist.new_group(backend="nccl") tp_group = dist.new_group(backend="nccl")
tp_size = dist.get_world_size(tp_group) tp_size = dist.get_world_size(tp_group)
# Initialize userbuffers # Initialize userbuffers
ag_cfg = { # Ring-exchange All-Gather overlap for fc1_fprop and fc2_dgrad ag_cfg = { # Ring-exchange All-Gather overlap for fc1_fprop and fc2_dgrad
'method': 'ring_exchange', "method": "ring_exchange",
'num_splits' : 8, "num_splits": 8,
'num_sm' : 1, "num_sm": 1,
'set_sm_margin' : False, "set_sm_margin": False,
} }
rs_cfg = { # Reduce-scatter overlap for fc1_dgrad and fc2_fprop rs_cfg = { # Reduce-scatter overlap for fc1_dgrad and fc2_fprop
'method': 'ring_exchange', "method": "ring_exchange",
'num_splits' : 4, "num_splits": 4,
'num_sm' : 1, "num_sm": 1,
'set_sm_margin' : True, "set_sm_margin": True,
} }
hidden_size = opts.num_heads * opts.head_dim hidden_size = opts.num_heads * opts.head_dim
batched_size = opts.seq_length * opts.batch_size batched_size = opts.seq_length * opts.batch_size
...@@ -78,30 +93,31 @@ def train(opts): ...@@ -78,30 +93,31 @@ def train(opts):
te.initialize_ub( te.initialize_ub(
[batched_size, hidden_size], [batched_size, hidden_size],
tp_group, tp_group,
use_fp8 = opts.fp8, use_fp8=opts.fp8,
dtype = torch.bfloat16, dtype=torch.bfloat16,
ub_cfgs = { ub_cfgs={
'fc1_fprop': ag_cfg, "fc1_fprop": ag_cfg,
'fc1_dgrad': rs_cfg, "fc1_dgrad": rs_cfg,
'fc2_fprop': rs_cfg, "fc2_fprop": rs_cfg,
'fc2_dgrad': ag_cfg, "fc2_dgrad": ag_cfg,
}, },
) )
# #
model = te.LayerNormMLP( model = te.LayerNormMLP(
hidden_size, opts.mlp_expansion_factor * hidden_size, hidden_size,
params_dtype = torch.bfloat16, opts.mlp_expansion_factor * hidden_size,
device = 'cuda', params_dtype=torch.bfloat16,
tp_group = tp_group, device="cuda",
tp_size = tp_size, tp_group=tp_group,
set_parallel_mode = True, tp_size=tp_size,
sequence_parallel = True, # this is required for comm+GEMM overlap set_parallel_mode=True,
seq_length = opts.seq_length, sequence_parallel=True, # this is required for comm+GEMM overlap
micro_batch_size = opts.batch_size, seq_length=opts.seq_length,
ub_overlap_rs_dgrad = not opts.no_comm_overlap, micro_batch_size=opts.batch_size,
ub_overlap_rs = not opts.no_comm_overlap, ub_overlap_rs_dgrad=not opts.no_comm_overlap,
ub_overlap_ag = not opts.no_comm_overlap, ub_overlap_rs=not opts.no_comm_overlap,
ub_overlap_ag=not opts.no_comm_overlap,
) )
# Initialize optimizer with model parameters # Initialize optimizer with model parameters
...@@ -109,16 +125,19 @@ def train(opts): ...@@ -109,16 +125,19 @@ def train(opts):
# Fp8 recipe setup # Fp8 recipe setup
fp8_format = Format.HYBRID fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
amax_compute_algo="max")
# Start dummy "training" iterations # Start dummy "training" iterations
for i in range(opts.num_iters): for i in range(opts.num_iters):
dist_print(f"Iter {i+1}", all_ranks=opts.verbose) dist_print(f"Iter {i+1}", all_ranks=opts.verbose)
dist_print("|-- Generate random input batch", all_ranks=opts.verbose) dist_print("|-- Generate random input batch", all_ranks=opts.verbose)
x = torch.rand((opts.seq_length // tp_size, opts.batch_size, hidden_size), x = torch.rand(
dtype=torch.bfloat16, device='cuda', requires_grad=True) (opts.seq_length // tp_size, opts.batch_size, hidden_size),
dtype=torch.bfloat16,
device="cuda",
requires_grad=True,
)
dist_print("|-- Forward pass", all_ranks=opts.verbose) dist_print("|-- Forward pass", all_ranks=opts.verbose)
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group): with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group):
...@@ -135,17 +154,15 @@ def train(opts): ...@@ -135,17 +154,15 @@ def train(opts):
te.destroy_ub() te.destroy_ub()
dist.destroy_process_group() dist.destroy_process_group()
if __name__ == "__main__": if __name__ == "__main__":
if "TORCHELASTIC_RUN_ID" in os.environ.keys(): if "TORCHELASTIC_RUN_ID" in os.environ.keys():
args = parse_args() args = parse_args()
train(args) train(args)
else: else:
subprocess.run( subprocess.run(
[ ["torchrun", f"--nproc-per-node={torch.cuda.device_count()}", *sys.argv],
'torchrun', f'--nproc-per-node={torch.cuda.device_count()}',
*sys.argv
],
env=os.environ, env=os.environ,
check=True check=True,
) )
os._exit(0) os._exit(0)
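Taken together, the comm+GEMM overlap example boils down to three pieces: a one-time te.initialize_ub() call that registers Userbuffers for the (sequence * batch, hidden) communication shape, per-GEMM overlap configs (ring-exchange all-gather and reduce-scatter), and the ub_overlap_* flags plus sequence_parallel=True on the te.LayerNormMLP module. A condensed sketch, assuming the same distributed setup as above (tp_group, tp_size) and placeholder sizes:

hidden, seq, batch = 8192, 2048, 2  # placeholder sizes
te.initialize_ub(
    [seq * batch, hidden],
    tp_group,
    use_fp8=False,
    dtype=torch.bfloat16,
)
mlp = te.LayerNormMLP(
    hidden,
    4 * hidden,
    params_dtype=torch.bfloat16,
    device="cuda",
    tp_group=tp_group,
    tp_size=tp_size,
    set_parallel_mode=True,
    sequence_parallel=True,  # required for comm+GEMM overlap
    seq_length=seq,
    micro_batch_size=batch,
    ub_overlap_ag=True,
    ub_overlap_rs=True,
    ub_overlap_rs_dgrad=True,
)
# ... forward/backward iterations as in the training loop above ...
te.destroy_ub()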
...@@ -14,7 +14,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision ...@@ -14,7 +14,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing, apply_activation_checkpointing,
checkpoint_wrapper checkpoint_wrapper,
) )
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
...@@ -29,46 +29,56 @@ rng_seed = 1234 ...@@ -29,46 +29,56 @@ rng_seed = 1234
torch.manual_seed(rng_seed) torch.manual_seed(rng_seed)
torch.cuda.manual_seed(rng_seed) torch.cuda.manual_seed(rng_seed)
CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker() CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker()
CUDA_RNG_STATES_TRACKER.add('model-parallel-rng', rng_seed) CUDA_RNG_STATES_TRACKER.add("model-parallel-rng", rng_seed)
def get_cuda_rng_tracker(): def get_cuda_rng_tracker():
return CUDA_RNG_STATES_TRACKER return CUDA_RNG_STATES_TRACKER
def apply_fsdp_checkpointing(model, blocks): def apply_fsdp_checkpointing(model, blocks):
"""apply activation checkpointing to model """apply activation checkpointing to model
returns None as model is updated directly returns None as model is updated directly
""" """
wrapper = lambda m: checkpoint_wrapper(m, wrapper = lambda m: checkpoint_wrapper(
checkpoint_fn=te.distributed.checkpoint, m,
use_reentrant=False, checkpoint_fn=te.distributed.checkpoint,
get_rng_state_tracker=get_cuda_rng_tracker) use_reentrant=False,
get_rng_state_tracker=get_cuda_rng_tracker,
)
check_fn = lambda submodule: isinstance(submodule, blocks) check_fn = lambda submodule: isinstance(submodule, blocks)
apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn) apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
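A hypothetical call site for the helper above, applied after the model is built (and typically after FSDP wrapping); it mutates the model in place rather than returning a new one:

# Recompute te.TransformerLayer activations on backward instead of storing them.
apply_fsdp_checkpointing(te_model, blocks=te.TransformerLayer)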
def lowercase(s): def lowercase(s):
return str(s).lower() return str(s).lower()
def torch_dtype(d): def torch_dtype(d):
typemap = { typemap = {
'fp32' : torch.float32, "fp32": torch.float32,
'float32' : torch.float32, "float32": torch.float32,
'fp16' : torch.float16, "fp16": torch.float16,
'float16' : torch.float16, "float16": torch.float16,
'bf16' : torch.bfloat16, "bf16": torch.bfloat16,
'bfloat16' : torch.bfloat16 "bfloat16": torch.bfloat16,
} }
if lowercase(d) not in typemap.keys(): if lowercase(d) not in typemap.keys():
raise TypeError raise TypeError
return typemap[lowercase(d)] return typemap[lowercase(d)]
te_layer_map = { te_layer_map = {
'linear': te.Linear, "linear": te.Linear,
'layernorm': te.LayerNorm, "layernorm": te.LayerNorm,
'rmsnorm': te.RMSNorm, "rmsnorm": te.RMSNorm,
'layernormlinear': te.LayerNormLinear, "layernormlinear": te.LayerNormLinear,
'layernormmlp': te.LayerNormMLP, "layernormmlp": te.LayerNormMLP,
'multiheadattention': te.MultiheadAttention, "multiheadattention": te.MultiheadAttention,
'transformerlayer': te.TransformerLayer "transformerlayer": te.TransformerLayer,
} }
def te_layer(l): def te_layer(l):
if l is not None: if l is not None:
if lowercase(l) not in te_layer_map.keys(): if lowercase(l) not in te_layer_map.keys():
...@@ -76,74 +86,120 @@ def te_layer(l): ...@@ -76,74 +86,120 @@ def te_layer(l):
return te_layer_map[lowercase(l)] return te_layer_map[lowercase(l)]
return None return None
def get_layer_args(opts): def get_layer_args(opts):
hidden_size = opts.num_heads * opts.head_dim hidden_size = opts.num_heads * opts.head_dim
layer_args = (hidden_size, ) layer_args = (hidden_size,)
layer_kwargs = { layer_kwargs = {
'params_dtype': opts.dtype, "params_dtype": opts.dtype,
'device': 'cuda' if opts.no_defer_init else 'meta', "device": "cuda" if opts.no_defer_init else "meta",
'get_rng_state_tracker': get_cuda_rng_tracker, "get_rng_state_tracker": get_cuda_rng_tracker,
} }
if opts.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]: if opts.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
ffn_hidden_size = 3 * hidden_size if opts.num_layers == 1 else hidden_size ffn_hidden_size = 3 * hidden_size if opts.num_layers == 1 else hidden_size
layer_args += (ffn_hidden_size, ) layer_args += (ffn_hidden_size,)
layer_kwargs['bias'] = True layer_kwargs["bias"] = True
if opts.layer_type == te.LayerNormMLP: if opts.layer_type == te.LayerNormMLP:
layer_kwargs['seq_length'] = opts.seq_length layer_kwargs["seq_length"] = opts.seq_length
elif opts.layer_type == te.MultiheadAttention: elif opts.layer_type == te.MultiheadAttention:
layer_args += (opts.num_heads, ) layer_args += (opts.num_heads,)
layer_kwargs['fuse_qkv_params'] = True layer_kwargs["fuse_qkv_params"] = True
layer_kwargs['input_layernorm'] = True layer_kwargs["input_layernorm"] = True
elif opts.layer_type == te.TransformerLayer: elif opts.layer_type == te.TransformerLayer:
layer_args += (3 * hidden_size, opts.num_heads) layer_args += (3 * hidden_size, opts.num_heads)
layer_kwargs['fuse_qkv_params'] = True layer_kwargs["fuse_qkv_params"] = True
layer_kwargs['seq_length'] = opts.seq_length layer_kwargs["seq_length"] = opts.seq_length
return layer_args, layer_kwargs return layer_args, layer_kwargs
def parse_fsdp_args(): def parse_fsdp_args():
parser = argparse.ArgumentParser(description="Run Transformer Engine modules with the " + parser = argparse.ArgumentParser(
"torch.distributed.fsdp.FullyShardedDataParallel strategy.") description="Run Transformer Engine modules with the "
parser.add_argument('-v', "--verbose", action="store_true", default=False, + "torch.distributed.fsdp.FullyShardedDataParallel strategy."
help="Print out information from all GPUs instead of only the root GPU-0.") )
parser.add_argument('-b', "--batch-size", type=int, default=32, parser.add_argument(
help="Input batch size.") "-v",
parser.add_argument('-s', "--seq-length", type=int, default=1048, "--verbose",
help="Input sequence length.") action="store_true",
parser.add_argument('-n', "--num-heads", type=int, default=16, default=False,
help="Number of attention heads.") help="Print out information from all GPUs instead of only the root GPU-0.",
parser.add_argument('-d', "--head-dim", type=int, default=128, )
help="Dimension of each attention head (number of KV channels).") parser.add_argument("-b", "--batch-size", type=int, default=32, help="Input batch size.")
parser.add_argument('-i', "--num-iters", type=int, default=5, parser.add_argument("-s", "--seq-length", type=int, default=1048, help="Input sequence length.")
help="Number of dummy 'training' iterations.") parser.add_argument(
parser.add_argument('-k', "--num-layers", type=int, default=3, "-n", "--num-heads", type=int, default=16, help="Number of attention heads."
help="Number of modules chained together with nn.Sequential.") )
parser.add_argument("--layer-type", type=te_layer, default=te.TransformerLayer, parser.add_argument(
choices=list(te_layer_map.values()), "-d",
help="TE module type used to construct the test model.") "--head-dim",
parser.add_argument("--seed", type=int, default=1234, type=int,
help="PyTorch RNG seed.") default=128,
parser.add_argument("--profile-memory", action="store_true", help="Dimension of each attention head (number of KV channels).",
help="Enable memory profiling via torch.profiler.profile().") )
parser.add_argument("--profile-name", type=str, default=None, parser.add_argument(
help="File path for memory profiling.") "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
parser.add_argument("--checkpoint-layer", type=te_layer, default=None, )
help="Recompute activations of the selected layer during the backward " + \ parser.add_argument(
"pass instead of saving.") "-k",
parser.add_argument("--no-fp8", action="store_true", default=False, "--num-layers",
help="Disables the te.fp8_autocast() context.") type=int,
parser.add_argument("--no-defer-init", action="store_true", default=3,
help="Defer module parameter initialization until after FSDP sharding.") help="Number of modules chained together with nn.Sequential.",
parser.add_argument("--no-te-fsdp", action="store_true", )
help="Disable sharding of intermediate/activation tensors in TE modules.") parser.add_argument(
parser.add_argument("--dtype", type=torch_dtype, default=torch.bfloat16, "--layer-type",
help="Data type for input tensor and Transformer Engine module parameters.") type=te_layer,
default=te.TransformerLayer,
choices=list(te_layer_map.values()),
help="TE module type used to construct the test model.",
)
parser.add_argument("--seed", type=int, default=1234, help="PyTorch RNG seed.")
parser.add_argument(
"--profile-memory",
action="store_true",
help="Enable memory profiling via torch.profiler.profile().",
)
parser.add_argument(
"--profile-name", type=str, default=None, help="File path for memory profiling."
)
parser.add_argument(
"--checkpoint-layer",
type=te_layer,
default=None,
help="Recompute activations of the selected layer during the backward "
+ "pass instead of saving.",
)
parser.add_argument(
"--no-fp8",
action="store_true",
default=False,
help="Disables the te.fp8_autocast() context.",
)
parser.add_argument(
"--no-defer-init",
action="store_true",
help="Defer module parameter initialization until after FSDP sharding.",
)
parser.add_argument(
"--no-te-fsdp",
action="store_true",
help="Disable sharding of intermediate/activation tensors in TE modules.",
)
parser.add_argument(
"--dtype",
type=torch_dtype,
default=torch.bfloat16,
help="Data type for input tensor and Transformer Engine module parameters.",
)
return parser.parse_args() return parser.parse_args()
def dist_print(text, all_ranks=False, no_new_line=False): def dist_print(text, all_ranks=False, no_new_line=False):
if LOCAL_RANK == 0 or all_ranks: if LOCAL_RANK == 0 or all_ranks:
end = '' if no_new_line else '\n' end = "" if no_new_line else "\n"
print(f"[GPU-{LOCAL_RANK}] " + text, end=end) print(f"[GPU-{LOCAL_RANK}] " + text, end=end)
def train(opts): def train(opts):
# Initialize torch.distributed global process group # Initialize torch.distributed global process group
dist.init_process_group(backend="nccl") dist.init_process_group(backend="nccl")
...@@ -157,7 +213,7 @@ def train(opts): ...@@ -157,7 +213,7 @@ def train(opts):
te_layer_list = [] te_layer_list = []
for i in range(opts.num_layers): for i in range(opts.num_layers):
if opts.layer_type in [te.MultiheadAttention, te.TransformerLayer]: if opts.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
layer_kwargs['layer_number'] = i+1 layer_kwargs["layer_number"] = i + 1
te_layer_list.append(opts.layer_type(*layer_args, **layer_kwargs)) te_layer_list.append(opts.layer_type(*layer_args, **layer_kwargs))
te_model = nn.Sequential(*te_layer_list) te_model = nn.Sequential(*te_layer_list)
else: else:
...@@ -171,20 +227,23 @@ def train(opts): ...@@ -171,20 +227,23 @@ def train(opts):
# Wrap the model with FSDP # Wrap the model with FSDP
# NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and # NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and
# controls all communication. # controls all communication.
all_gpus = dist.new_group(backend='nccl') all_gpus = dist.new_group(backend="nccl")
fsdp_wrap_policy = always_wrap_policy fsdp_wrap_policy = always_wrap_policy
if opts.layer_type == te.TransformerLayer: if opts.layer_type == te.TransformerLayer:
# NOTE: FSDP causes illegal memory access without this special policy for Transformers # NOTE: FSDP causes illegal memory access without this special policy for Transformers
fsdp_wrap_policy = partial(transformer_auto_wrap_policy, fsdp_wrap_policy = partial(
transformer_layer_cls={te.TransformerLayer}) transformer_auto_wrap_policy, transformer_layer_cls={te.TransformerLayer}
te_model = FullyShardedDataParallel(te_model, )
process_group=all_gpus, te_model = FullyShardedDataParallel(
use_orig_params=True, te_model,
mixed_precision=MixedPrecision( process_group=all_gpus,
param_dtype=opts.dtype, use_orig_params=True,
reduce_dtype=torch.float32, mixed_precision=MixedPrecision(
), param_dtype=opts.dtype,
auto_wrap_policy=fsdp_wrap_policy) reduce_dtype=torch.float32,
),
auto_wrap_policy=fsdp_wrap_policy,
)
if opts.checkpoint_layer is not None: if opts.checkpoint_layer is not None:
# Recompute the activations of the selected layer during the backward pass instead of # Recompute the activations of the selected layer during the backward pass instead of
...@@ -218,8 +277,13 @@ def train(opts): ...@@ -218,8 +277,13 @@ def train(opts):
for i in range(opts.num_iters): for i in range(opts.num_iters):
# Generate a random input batch # Generate a random input batch
x = torch.rand(opts.seq_length, opts.batch_size, opts.num_heads*opts.head_dim, x = torch.rand(
dtype=opts.dtype, device='cuda') opts.seq_length,
opts.batch_size,
opts.num_heads * opts.head_dim,
dtype=opts.dtype,
device="cuda",
)
# fp8_autocast needs to be given the FSDP process group for amax reductions # fp8_autocast needs to be given the FSDP process group for amax reductions
with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus): with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
y = te_model(x) y = te_model(x)
...@@ -230,7 +294,6 @@ def train(opts): ...@@ -230,7 +294,6 @@ def train(opts):
optim.zero_grad(set_to_none=True) optim.zero_grad(set_to_none=True)
del x del x
if opts.profile_memory: if opts.profile_memory:
torch.cuda.memory._dump_snapshot(f"gpu{LOCAL_RANK}_{opts.profile_name}.pickle") torch.cuda.memory._dump_snapshot(f"gpu{LOCAL_RANK}_{opts.profile_name}.pickle")
torch.cuda.memory._record_memory_history(enabled=None) torch.cuda.memory._record_memory_history(enabled=None)
...@@ -238,7 +301,7 @@ def train(opts): ...@@ -238,7 +301,7 @@ def train(opts):
end.record() end.record()
torch.cuda.synchronize() torch.cuda.synchronize()
peak_mem = torch.cuda.max_memory_allocated() peak_mem = torch.cuda.max_memory_allocated()
train_time = start.elapsed_time(end)/1000. train_time = start.elapsed_time(end) / 1000.0
dist_print(f"Training Time: {train_time}s") dist_print(f"Training Time: {train_time}s")
dist_print(f"Avg. Iter. Time: {train_time / opts.num_iters}s") dist_print(f"Avg. Iter. Time: {train_time / opts.num_iters}s")
dist_print(f"Peak Memory Use: {peak_mem * 1e-6}MBs") dist_print(f"Peak Memory Use: {peak_mem * 1e-6}MBs")
......