Unverified Commit 9416519d authored by Kirthi Shankar Sivamani, committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
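
The hunks that follow are consistent with running an opinionated Python auto-formatter over the repository: string quotes are normalized to double quotes, trailing commas are added, and long call argument lists are re-wrapped one argument per line. A minimal sketch of the kind of invocation that produces such a diff is shown below; the tool, the 100-character line length, and the target path are assumptions, since the commit records only the resulting changes, not the command used.

    # Hypothetical sketch only: assumes the formatter is black and that it is installed.
    # The actual tool, options, and targets used for this commit are not stated here.
    import subprocess

    subprocess.run(["black", "--line-length", "100", "."], check=True)
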
......@@ -14,7 +14,7 @@ from tests.pytorch.fused_attn.test_fused_attn import (
_is_flash_attention_supported,
_is_fused_attention_supported,
_is_unfused_attention_supported,
_run_dot_product_attention
_run_dot_product_attention,
)
pd.set_option("display.precision", 4)
......@@ -28,7 +28,7 @@ ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = 'bshd_bshd_bshd'
qkv_layout = "bshd_bshd_bshd"
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
......@@ -38,16 +38,17 @@ is_training = True
model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
}
def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
config = model_configs[model]
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=5e-3, rtol=5e-3)
......@@ -57,17 +58,31 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
for i in range(warmup_iters):
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, config, "FlashAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FlashAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
if fused_attn_supported and flash_attn_supported:
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd):
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
torch.cuda.cudart().cudaProfilerStart()
......@@ -76,8 +91,15 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
if fused_attn_supported:
for i in range(num_iters):
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
torch.cuda.synchronize()
fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0
......@@ -87,81 +109,113 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
if flash_attn_supported:
for i in range(num_iters):
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, config, "FlashAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FlashAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
torch.cuda.synchronize()
flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0
df = pd.read_csv('times.csv')
df = pd.concat([
df,
pd.DataFrame(
[[fused_attn_time*1e3/num_iters, 0, 0, 0,
flash_attn_time*1e3/num_iters, 0, 0, 0, 0]], columns=df.columns)],
ignore_index=True
)
df.to_csv('times.csv',index=False)
df = pd.read_csv("times.csv")
df = pd.concat(
[
df,
pd.DataFrame(
[
[
fused_attn_time * 1e3 / num_iters,
0,
0,
0,
flash_attn_time * 1e3 / num_iters,
0,
0,
0,
0,
]
],
columns=df.columns,
),
],
ignore_index=True,
)
df.to_csv("times.csv", index=False)
torch.cuda.cudart().cudaProfilerStop()
def parse_results(per_cudnn, per_flash, model):
filename = f'prof_{model}_cuda_gpu_trace.csv'
df = pd.read_csv(os.path.join('./',filename))
df_times = pd.read_csv('times.csv')
row = len(df_times.index)-1
filename = f"prof_{model}_cuda_gpu_trace.csv"
df = pd.read_csv(os.path.join("./", filename))
df_times = pd.read_csv("times.csv")
row = len(df_times.index) - 1
if per_cudnn > 0:
t_cudnn_all = df[df['Name'].str.contains('cudnn')]['Duration (ns)'].to_numpy()
t_cudnn_all = df[df["Name"].str.contains("cudnn")]["Duration (ns)"].to_numpy()
t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
t_cudnn_avg = np.average(t_cudnn_all, axis=0)
df_times.loc[row, 'FusedAttention Kernels (fwd)'] = t_cudnn_avg[0]/1e6
df_times.loc[row, 'FusedAttention Kernels (bwd)'] = t_cudnn_avg[1:4].sum()/1e6
df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)'] = t_cudnn_avg.sum()/1e6
df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6
df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6
df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6
if per_flash > 0:
t_flash_all = df[df['Name'].str.contains('void flash')]['Duration (ns)'].to_numpy()
t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
t_flash_all = t_flash_all.reshape(-1, per_flash)
t_flash_avg = np.average(t_flash_all, axis=0)
df_times.loc[row, 'FlashAttention Kernels (fwd)'] = t_flash_avg[0]/1e6
df_times.loc[row, 'FlashAttention Kernels (bwd)'] = t_flash_avg[1:4].sum()/1e6
df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] = t_flash_avg.sum()/1e6
df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6
df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6
if per_cudnn > 0 and per_flash > 0:
df_times.loc[row, 'Fused vs Flash Kernels Speedup (fwd+bwd)'] = \
df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] / \
df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)']
df_times.to_csv('times.csv',index=False)
df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = (
df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"]
/ df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"]
)
df_times.to_csv("times.csv", index=False)
def main():
times = pd.DataFrame(
columns=[
'FusedAttention Module',
'FusedAttention Kernels (fwd)',
'FusedAttention Kernels (bwd)',
'FusedAttention Kernels (fwd+bwd)',
'FlashAttention Module',
'FlashAttention Kernels (fwd)',
'FlashAttention Kernels (bwd)',
'FlashAttention Kernels (fwd+bwd)',
'Fused vs Flash Kernels Speedup (fwd+bwd)',
])
times.to_csv('times.csv',index=False)
columns=[
"FusedAttention Module",
"FusedAttention Kernels (fwd)",
"FusedAttention Kernels (bwd)",
"FusedAttention Kernels (fwd+bwd)",
"FlashAttention Module",
"FlashAttention Kernels (fwd)",
"FlashAttention Kernels (bwd)",
"FlashAttention Kernels (fwd+bwd)",
"Fused vs Flash Kernels Speedup (fwd+bwd)",
]
)
times.to_csv("times.csv", index=False)
device_id = torch.cuda.current_device()
device_properties = torch.cuda.get_device_properties(device_id)
print(f"Device {device_id}: "
print(
f"Device {device_id}: "
f"{device_properties.name} GPU, "
f"sm{device_properties.major}{device_properties.minor} compute capability, "
f"{device_properties.total_memory/1024**3:.1f}GB memory")
f"{device_properties.total_memory/1024**3:.1f}GB memory"
)
for model in model_configs.keys():
config = model_configs[model]
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
config, dtype, qkv_layout=qkv_layout,
config,
dtype,
qkv_layout=qkv_layout,
)
fused_attn_supported = fused_attn_supported and not swa
flash_attn_supported = _is_flash_attention_supported(config)
print(f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
f'{" and flash-attention" if flash_attn_supported else ""}...')
print(
f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
f'{" and flash-attention" if flash_attn_supported else ""}...'
)
prof_cmd = [
"nsys",
......@@ -175,8 +229,8 @@ def main():
f""" "import benchmark_attention;""",
f"""benchmark_attention.benchmark_dot_product_attention("""
f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
]
prof_cmd = ' '.join(prof_cmd)
]
prof_cmd = " ".join(prof_cmd)
subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
stats_cmd = [
"nsys",
......@@ -190,17 +244,17 @@ def main():
"--force-export=true",
f"--output=prof_{model}",
f"prof_{model}.nsys-rep",
]
]
if fused_attn_supported:
num_kernels_cudnn = 4
if config.attn_bias_type == 'post_scale_bias':
num_kernels_cudnn = num_kernels_cudnn+1
if config.attn_bias_type == "post_scale_bias":
num_kernels_cudnn = num_kernels_cudnn + 1
if config.num_heads != config.num_gqa_groups:
num_kernels_cudnn = num_kernels_cudnn+2
num_kernels_cudnn = num_kernels_cudnn + 2
else:
num_kernels_cudnn = 0
num_kernels_flash = 4 if flash_attn_supported else 0
stats_cmd = ' '.join(stats_cmd)
stats_cmd = " ".join(stats_cmd)
subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
parse_cmd = [
"python",
......@@ -208,18 +262,23 @@ def main():
f""" "import benchmark_attention;""",
f"""benchmark_attention.parse_results("""
f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
]
parse_cmd = ' '.join(parse_cmd)
]
parse_cmd = " ".join(parse_cmd)
subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
df_times = pd.read_csv('times.csv')
df_times = pd.read_csv("times.csv")
df_times.index = list(model_configs.keys())
a=df_times[['FusedAttention Kernels (fwd+bwd)',
'FlashAttention Kernels (fwd+bwd)',
'Fused vs Flash Kernels Speedup (fwd+bwd)']]
a.columns = ['cuDNN fwd+bwd (ms)', 'flash-attn fwd+bwd (ms)', 'cuDNN vs flash speedup']
a = df_times[
[
"FusedAttention Kernels (fwd+bwd)",
"FlashAttention Kernels (fwd+bwd)",
"Fused vs Flash Kernels Speedup (fwd+bwd)",
]
]
a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
print()
print(a)
if __name__ == "__main__":
main()
......@@ -64,6 +64,7 @@ class CMakeExtension(setuptools.Extension):
configure_command.append("-GNinja")
import pybind11
pybind11_dir = Path(pybind11.__file__).resolve().parent
pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")
......@@ -130,6 +131,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
else:
# Only during release sdist build.
import transformer_engine
search_paths = list(Path(transformer_engine.__path__[0]).iterdir())
del transformer_engine
......@@ -142,8 +144,9 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
# Figure out stub file path
module_name = paddle_ext.name
assert module_name.endswith("_pd_"), \
"Expected Paddle extension module to end with '_pd_'"
assert module_name.endswith(
"_pd_"
), "Expected Paddle extension module to end with '_pd_'"
stub_name = module_name[:-4] # remove '_pd_'
stub_path = os.path.join(self.build_lib, "transformer_engine", stub_name + ".py")
Path(stub_path).parent.mkdir(exist_ok=True, parents=True)
......@@ -158,6 +161,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
# Write stub file
print(f"Writing Paddle stub for {lib_name} into file {stub_path}")
from paddle.utils.cpp_extension.extension_utils import custom_write_stub
custom_write_stub(lib_name, stub_path)
# Ensure that binaries are not in global package space.
......@@ -182,13 +186,14 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
# extra_compile_args is a dict.
for ext in self.extensions:
if isinstance(ext.extra_compile_args, dict):
for target in ['cxx', 'nvcc']:
for target in ["cxx", "nvcc"]:
if target not in ext.extra_compile_args.keys():
ext.extra_compile_args[target] = []
# Define new _compile method that redirects to NVCC for .cu and .cuh files.
original_compile_fn = self.compiler._compile
self.compiler.src_extensions += ['.cu', '.cuh']
self.compiler.src_extensions += [".cu", ".cuh"]
def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
# Copy before we make any modifications.
cflags = copy.deepcopy(extra_postargs)
......@@ -197,31 +202,31 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
_, nvcc_bin = cuda_path()
original_compiler = self.compiler.compiler_so
if os.path.splitext(src)[1] in ['.cu', '.cuh']:
self.compiler.set_executable('compiler_so', str(nvcc_bin))
if os.path.splitext(src)[1] in [".cu", ".cuh"]:
self.compiler.set_executable("compiler_so", str(nvcc_bin))
if isinstance(cflags, dict):
cflags = cflags['nvcc']
cflags = cflags["nvcc"]
# Add -fPIC if not already specified
if not any('-fPIC' in flag for flag in cflags):
cflags.extend(['--compiler-options', "'-fPIC'"])
if not any("-fPIC" in flag for flag in cflags):
cflags.extend(["--compiler-options", "'-fPIC'"])
# Forward unknown options
if not any('--forward-unknown-opts' in flag for flag in cflags):
cflags.append('--forward-unknown-opts')
if not any("--forward-unknown-opts" in flag for flag in cflags):
cflags.append("--forward-unknown-opts")
elif isinstance(cflags, dict):
cflags = cflags['cxx']
cflags = cflags["cxx"]
# Append -std=c++17 if not already in flags
if not any(flag.startswith('-std=') for flag in cflags):
cflags.append('-std=c++17')
if not any(flag.startswith("-std=") for flag in cflags):
cflags.append("-std=c++17")
return original_compile_fn(obj, src, ext, cc_args, cflags, pp_opts)
finally:
# Put the original compiler back in place.
self.compiler.set_executable('compiler_so', original_compiler)
self.compiler.set_executable("compiler_so", original_compiler)
self.compiler._compile = _compile_fn
......
......@@ -36,8 +36,8 @@ def setup_jax_extension(
]
# Compile flags
cxx_flags = [ "-O3" ]
nvcc_flags = [ "-O3" ]
cxx_flags = ["-O3"]
nvcc_flags = ["-O3"]
# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension
......@@ -47,9 +47,9 @@ def setup_jax_extension(
def _add_cflags(self, flags: List[str]) -> None:
if isinstance(self.extra_compile_args, dict):
cxx_flags = self.extra_compile_args.pop('cxx', [])
cxx_flags = self.extra_compile_args.pop("cxx", [])
cxx_flags += flags
self.extra_compile_args['cxx'] = cxx_flags
self.extra_compile_args["cxx"] = cxx_flags
else:
self.extra_compile_args[:0] = flags
......@@ -57,8 +57,5 @@ def setup_jax_extension(
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags
},
extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags},
)
......@@ -76,11 +76,12 @@ def setup_pytorch_extension(
# Libraries -- PyTorch CUDAExtension links to libcudart.so but not to libcuda.so
cuda_home, _ = cuda_path()
library_dirs = [ cuda_home / "compat" / "lib" ]
libraries = [ "cuda" ]
library_dirs = [cuda_home / "compat" / "lib"]
libraries = ["cuda"]
if os.getenv("UB_MPI_BOOTSTRAP"):
assert os.getenv("MPI_HOME") is not None, \
"MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1"
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1"
mpi_home = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_home / "include")
cxx_flags.append("-DUB_MPI_BOOTSTRAP")
......@@ -95,12 +96,12 @@ def setup_pytorch_extension(
return CUDAExtension(
name="transformer_engine_torch",
sources=[ str(src) for src in sources ],
include_dirs=[ str(inc) for inc in include_dirs ],
sources=[str(src) for src in sources],
include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
libraries=[ str(lib) for lib in libraries ],
library_dirs=[ str(lib_dir) for lib_dir in library_dirs ],
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
......@@ -18,11 +18,12 @@ def te_version() -> str:
root_path = Path(__file__).resolve().parent
with open(root_path / "VERSION.txt", "r") as f:
version = f.readline().strip()
if (not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0"))
and not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))):
if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")) and not bool(
int(os.getenv("NVTE_RELEASE_BUILD", "0"))
):
try:
output = subprocess.run(
["git", "rev-parse" , "--short", "HEAD"],
["git", "rev-parse", "--short", "HEAD"],
capture_output=True,
cwd=root_path,
check=True,
......
......@@ -174,7 +174,7 @@ def cuda_version() -> Tuple[int, ...]:
universal_newlines=True,
)
match = re.search(r"release\s*([\d.]+)", output.stdout)
version = match.group(1).split('.')
version = match.group(1).split(".")
return tuple(int(v) for v in version)
......@@ -224,9 +224,7 @@ def get_frameworks() -> List[str]:
_frameworks = [framework.lower() for framework in _frameworks]
for framework in _frameworks:
if framework not in supported_frameworks:
raise ValueError(
f"Transformer Engine does not support framework={framework}"
)
raise ValueError(f"Transformer Engine does not support framework={framework}")
return _frameworks
......@@ -242,8 +240,8 @@ def package_files(directory):
def copy_common_headers(te_src, dst):
headers = te_src / "common"
for file_path in glob.glob(os.path.join(str(headers), "**", '*.h'), recursive=True):
new_path = os.path.join(dst, file_path[len(str(te_src)) + 1:])
for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True):
new_path = os.path.join(dst, file_path[len(str(te_src)) + 1 :])
Path(new_path).parent.mkdir(exist_ok=True, parents=True)
shutil.copy(file_path, new_path)
......@@ -251,9 +249,10 @@ def copy_common_headers(te_src, dst):
def install_and_import(package):
"""Install a package via pip (if not already installed) and import into globals."""
import importlib
try:
importlib.import_module(package)
except ImportError:
subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
finally:
globals()[package] = importlib.import_module(package)
......@@ -28,21 +28,26 @@ if current_year == release_year:
else:
copyright_year = str(release_year) + "-" + str(current_year)
project = u'Transformer Engine'
copyright = u'{}, NVIDIA CORPORATION & AFFILIATES. All rights reserved.'.format(copyright_year)
author = u'NVIDIA CORPORATION & AFFILIATES'
project = "Transformer Engine"
copyright = "{}, NVIDIA CORPORATION & AFFILIATES. All rights reserved.".format(copyright_year)
author = "NVIDIA CORPORATION & AFFILIATES"
git_sha = os.getenv("GIT_SHA")
if not git_sha:
try:
git_sha = subprocess.check_output(["git", "log", "--pretty=format:'%h'", "-n1"]).decode('ascii').replace("'","").strip()
git_sha = (
subprocess.check_output(["git", "log", "--pretty=format:'%h'", "-n1"])
.decode("ascii")
.replace("'", "")
.strip()
)
except:
git_sha = u'0000000'
git_sha = "0000000"
git_sha = git_sha[:7] if len(git_sha) > 7 else git_sha
version = str(te_version + u"-" + git_sha)
version = str(te_version + "-" + git_sha)
release = te_version
# hack: version is used for html creation, so put the version picker
......@@ -51,58 +56,60 @@ option_on = " selected"
option_off = ""
release_opt = option_on
option_nr = 0
version = version + """<br/>
version = (
version
+ """<br/>
Version select: <select onChange="window.location.href = this.value" onFocus="this.selectedIndex = {0}">
<option value="https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html"{1}>Current release</option>
<option value="https://docs.nvidia.com/deeplearning/transformer-engine/documentation-archive.html">Older releases</option>
</select>""".format(option_nr, release_opt)
</select>""".format(
option_nr, release_opt
)
)
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.ifconfig',
'nbsphinx',
'breathe',
'autoapi.extension',
"sphinx.ext.autodoc",
"sphinx.ext.mathjax",
"sphinx.ext.napoleon",
"sphinx.ext.ifconfig",
"nbsphinx",
"breathe",
"autoapi.extension",
]
templates_path = ['_templates']
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
source_suffix = '.rst'
source_suffix = ".rst"
master_doc = 'index'
pygments_style = 'sphinx'
master_doc = "index"
pygments_style = "sphinx"
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
html_static_path = ['_static']
html_static_path = ["_static"]
html_show_sphinx = False
html_css_files = [
'css/nvidia_font.css',
'css/nvidia_footer.css',
"css/nvidia_font.css",
"css/nvidia_footer.css",
]
html_theme_options = {
'display_version': True,
'collapse_navigation': False,
'logo_only': False
}
html_theme_options = {"display_version": True, "collapse_navigation": False, "logo_only": False}
napoleon_custom_sections = [('Parallelism parameters', 'params_style'),
('Optimization parameters', 'params_style'),
('Values', 'params_style')]
napoleon_custom_sections = [
("Parallelism parameters", "params_style"),
("Optimization parameters", "params_style"),
("Values", "params_style"),
]
breathe_projects = {"TransformerEngine": os.path.abspath("doxygen/xml/")}
breathe_default_project = "TransformerEngine"
......
......@@ -4,8 +4,8 @@
import os
import torch
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
from transformer_engine.pytorch.distributed import _set_cuda_rng_state
from transformer_engine.pytorch.attention import DotProductAttention
......@@ -18,87 +18,105 @@ _cuda_rng_state = torch.cuda.get_rng_state()
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
def _run_dot_product_attention(
dtype: torch.dtype,
config: ModelConfig,
qkv_layout: str,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
dtype: torch.dtype,
config: ModelConfig,
qkv_layout: str,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
reset_rng_states()
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
dtype=torch.int32, device="cuda")
inp = torch.randn([config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
dtype=dtype, device="cuda")
q = inp[:,:,0,:,:]
k = inp[:,:,1,:,:]
v = inp[:,:,2,:,:]
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)
seqlens_kv = torch.full(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
inp = torch.randn(
[config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
dtype=dtype,
device="cuda",
)
q = inp[:, :, 0, :, :]
k = inp[:, :, 1, :, :]
v = inp[:, :, 2, :, :]
q.requires_grad = True
k.requires_grad = True
v.requires_grad = True
out_grad = torch.randn([config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
dtype=dtype, device="cuda")
out_grad = torch.randn(
[config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
dtype=dtype,
device="cuda",
)
# Create attention mask / bias
attention_mask = None
bias = None
if config.attn_mask_type == "arbitrary":
attention_mask = torch.randint(-10,10,
[config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv]).to(
dtype=torch.bool, device="cuda")
attention_mask = torch.randint(
-10,
10,
[config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv],
).to(dtype=torch.bool, device="cuda")
if config.attn_bias_type == "post_scale_bias":
# convert mask to bias
attention_mask = torch.randint(-10,10,
[config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv]).to(
dtype=torch.bool, device="cuda")
attention_mask = torch.randint(
-10,
10,
[config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv],
).to(dtype=torch.bool, device="cuda")
bias = attention_mask.clone()
neginf = -2**50 if dtype == torch.bfloat16 else -2**15
bias = torch.where(bias==0, 0, neginf).to(dtype=dtype, device='cuda')
neginf = -(2**50) if dtype == torch.bfloat16 else -(2**15)
bias = torch.where(bias == 0, 0, neginf).to(dtype=dtype, device="cuda")
bias.requires_grad = False
attention_mask = None
block = (
DotProductAttention(
config.num_heads,
config.head_dim,
num_gqa_groups=config.num_gqa_groups,
qkv_format='bshd',
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=None,
tp_group=None,
layer_number=1,
).to(dtype=dtype, device="cuda")
)
block = DotProductAttention(
config.num_heads,
config.head_dim,
num_gqa_groups=config.num_gqa_groups,
qkv_format="bshd",
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=None,
tp_group=None,
layer_number=1,
).to(dtype=dtype, device="cuda")
# Run a forward and backward pass
out = None
if config.attn_mask_type == "arbitrary":
out = block(q, k, v,
attention_mask=attention_mask, # attention_mask
qkv_format='bshd',
attn_mask_type=config.attn_mask_type, # 'arbitrary'
core_attention_bias_type=config.attn_bias_type, # 'no_bias'
core_attention_bias=bias, # None
)
out = block(
q,
k,
v,
attention_mask=attention_mask, # attention_mask
qkv_format="bshd",
attn_mask_type=config.attn_mask_type, # 'arbitrary'
core_attention_bias_type=config.attn_bias_type, # 'no_bias'
core_attention_bias=bias, # None
)
out.backward(out_grad)
if config.attn_bias_type == "post_scale_bias":
out = block(q, k, v,
attention_mask=attention_mask, # None
qkv_format='bshd',
attn_mask_type=config.attn_mask_type, # no_mask
core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias'
core_attention_bias=bias, # bias
)
out = block(
q,
k,
v,
attention_mask=attention_mask, # None
qkv_format="bshd",
attn_mask_type=config.attn_mask_type, # no_mask
core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias'
core_attention_bias=bias, # bias
)
out.backward(out_grad)
return out, (q.grad, k.grad, v.grad)
......@@ -107,19 +125,19 @@ def _run_dot_product_attention(
dtype = torch.bfloat16
model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"test_mask": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "arbitrary", "no_bias"),
"test_bias": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
"test_mask": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "arbitrary", "no_bias"),
"test_bias": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
}
print('Run with post_scale_bias:')
print("Run with post_scale_bias:")
config = model_configs["test_bias"]
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd')
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
print('Run with arbitrary mask:')
print("Run with arbitrary mask:")
config = model_configs["test_mask"]
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd')
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
for i in range(3):
torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2)
print('Test passed!')
print("Test passed!")
......@@ -14,7 +14,7 @@ from tests.pytorch.fused_attn.test_fused_attn import (
_is_flash_attention_supported,
_is_fused_attention_supported,
_is_unfused_attention_supported,
_run_dot_product_attention
_run_dot_product_attention,
)
# data type
......@@ -26,7 +26,7 @@ ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = 'bshd_bshd_bshd'
qkv_layout = "bshd_bshd_bshd"
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
......@@ -36,12 +36,13 @@ is_training = True
model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
}
def example_attention(model, fused_attn_supported, flash_attn_supported):
config = model_configs[model]
if dtype == torch.bfloat16:
......@@ -51,40 +52,58 @@ def example_attention(model, fused_attn_supported, flash_attn_supported):
if fused_attn_supported:
print()
print('Run cuDNN attention...')
print("Run cuDNN attention...")
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
if flash_attn_supported:
print()
print('Run flash-attention...')
print("Run flash-attention...")
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, config, "FlashAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FlashAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
if fused_attn_supported and flash_attn_supported:
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd):
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
print()
print('Test passed.')
print("Test passed.")
def main():
models = ['test_0']
models = ["test_0"]
for model in models:
config = model_configs[model]
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
config, dtype, qkv_layout=qkv_layout,
config,
dtype,
qkv_layout=qkv_layout,
)
fused_attn_supported = fused_attn_supported and not swa
flash_attn_supported = _is_flash_attention_supported(config)
example_attention(model, fused_attn_supported, flash_attn_supported)
if __name__ == "__main__":
main()
......@@ -8,14 +8,15 @@ import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import DelayedScaling, dist_group_type
def speedometer(
module: torch.nn.Module,
input: torch.Tensor,
output_grad: torch.Tensor,
forward_kwargs: dict = {},
fp8_autocast_kwargs: Optional[dict] = None,
timing_iters: int = 50,
warmup_iters: int = 50,
module: torch.nn.Module,
input: torch.Tensor,
output_grad: torch.Tensor,
forward_kwargs: dict = {},
fp8_autocast_kwargs: Optional[dict] = None,
timing_iters: int = 50,
warmup_iters: int = 50,
) -> None:
"""Measure average run time for a PyTorch module
......@@ -24,7 +25,7 @@ def speedometer(
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
if fp8_autocast_kwargs is None:
fp8_autocast_kwargs = { "enabled": False }
fp8_autocast_kwargs = {"enabled": False}
# Warmup runs
torch.cuda.synchronize()
......@@ -51,11 +52,12 @@ class DotProductAttention(torch.nn.Module):
Built with plain PyTorch modules.
"""
def __init__(
self,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float,
self,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float,
) -> None:
super().__init__()
self.projection_size = kv_channels * num_attention_heads
......@@ -63,21 +65,17 @@ class DotProductAttention(torch.nn.Module):
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.dropout = torch.nn.Dropout(attention_dropout)
def masked_softmax(
self,
inp: torch.Tensor,
mask: Optional[torch.Tensor]
) -> torch.Tensor:
def masked_softmax(self, inp: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
if mask is not None:
inp.masked_fill_(mask, -10000.0)
return torch.nn.Softmax(dim=-1)(inp)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
b = query.size(1)
np = query.size(2)
......@@ -90,7 +88,9 @@ class DotProductAttention(torch.nn.Module):
# [sk, b, np, hn] -> [sk, b * np, hn]
key = key.view(sk, b * np, -1)
bmm1 = torch.bmm(query.transpose(0, 1), key.transpose(0, 1).transpose(1, 2)) / self.norm_factor
bmm1 = (
torch.bmm(query.transpose(0, 1), key.transpose(0, 1).transpose(1, 2)) / self.norm_factor
)
# change view to [b, np, sq, sk]
attention_scores = bmm1.view(b, np, sq, sk)
......@@ -126,10 +126,11 @@ class BasicMLP(torch.nn.Module):
Built with plain PyTorch modules.
"""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
self,
hidden_size: int,
ffn_hidden_size: int,
) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True)
......@@ -137,7 +138,7 @@ class BasicMLP(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = torch.nn.functional.gelu(x, approximate='tanh')
x = torch.nn.functional.gelu(x, approximate="tanh")
x = self.linear2(x)
return x
......@@ -148,7 +149,7 @@ def share_parameters_with_basic_te_model(te_model, basic_model):
Parameter values are copied from pure PyTorch implementation.
"""
te_model.ln1.weight= basic_model.ln1.weight
te_model.ln1.weight = basic_model.ln1.weight
te_model.ln1.bias = basic_model.ln1.bias
te_model.qkv_projection.weight = basic_model.qkv_projection.weight
te_model.qkv_projection.bias = basic_model.qkv_projection.bias
......@@ -202,14 +203,15 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model):
te_model.layernorm_mlp.fc2_bias = basic_model.mlp.linear2.bias
def cast_to_representable(inp, scale = 1., fp8_format='e4m3'):
def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"):
import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
fp8_type = tex.DType.kFloat8E4M3 if fp8_format == 'e4m3' else tex.DType.kFloat8E5M2
fp8_type = tex.DType.kFloat8E4M3 if fp8_format == "e4m3" else tex.DType.kFloat8E5M2
input_type = TE_DType[inp.dtype]
meta = tex.FP8TensorMeta()
meta.scale = torch.ones(1,dtype=torch.float32, device="cuda") * scale
meta.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
meta.scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") / scale
meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
ret = texcpp.cast_to_fp8(inp, meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_type)
......
......@@ -15,11 +15,17 @@ from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init
import transformers
from transformers.models.llama.modeling_llama import LlamaModel, LlamaForCausalLM, LlamaRMSNorm, LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaModel,
LlamaForCausalLM,
LlamaRMSNorm,
LlamaConfig,
)
from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
from transformers.utils import WEIGHTS_INDEX_NAME
from transformers.utils.hub import get_checkpoint_shard_files
@contextmanager
def replace_decoder(te_decoder_cls):
"""
......@@ -43,6 +49,7 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
args: positional args (for compatibility with `LlamaDecoderLayer`)
kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
"""
def __init__(self, config, *args, **kwargs):
super().__init__(
hidden_size=config.hidden_size,
......@@ -56,22 +63,22 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
normalization="RMSNorm",
activation="swiglu",
attn_input_format="bshd",
num_gqa_groups=config.num_key_value_heads
num_gqa_groups=config.num_key_value_heads,
)
te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)
te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
def forward(self,
hidden_states,
*args,
attention_mask,
**kwargs):
def forward(self, hidden_states, *args, attention_mask, **kwargs):
"""
Custom forward to make sure we only pass relevant arguments to the
forward pass of the `TransformerLayer`. Also, make sure the output
format matches the output of the HF's `LlamaDecoderLayer`.
"""
return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),)
return (
super().forward(
hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
),
)
class TELlamaForCausalLM:
......@@ -95,21 +102,29 @@ class TELlamaForCausalLM:
Custom method adapted from `from_pretrained` method in HuggingFace
Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
"""
vanilla_model = cls(config).to(kwargs['torch_dtype'])
vanilla_model = cls(config).to(kwargs["torch_dtype"])
is_local = os.path.isdir(pretrained_model_name_or_path)
subfolder = ""
variant = None
if os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant))
):
os.path.join(
pretrained_model_name_or_path,
subfolder,
_add_variant("model.safetensors.index.json", variant),
)
):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant)
pretrained_model_name_or_path,
subfolder,
_add_variant("model.safetensors.index.json", variant),
)
is_sharded = True
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
):
os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
)
):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
......@@ -118,10 +133,9 @@ class TELlamaForCausalLM:
else:
raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path,
archive_file,
pretrained_model_name_or_path,
archive_file,
)
# If the checkpoint is not sharded, it's a trivial sharding case
......@@ -142,48 +156,63 @@ class TELlamaForCausalLM:
return vanilla_model
def replace_params(hf_state_dict, te_state_dict, config):
# collect all layer prefixes to update
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
layer_prefix_pat = 'model.layers.\d+.'
layer_prefix_pat = "model.layers.\d+."
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())
for layer_prefix in all_layer_prefixes:
# When loading weights into models with less number of layers, skip the
# copy if the corresponding layer doesn't exist in HF model
if layer_prefix + 'input_layernorm.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]
if layer_prefix + "input_layernorm.weight" in hf_state_dict:
te_state_dict[layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"].data[
:
] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]
if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "self_attention.layernorm_qkv.query_weight"].data[:] = (
hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
)
if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]
if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "self_attention.layernorm_qkv.key_weight"].data[:] = (
hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
)
if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]
if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "self_attention.layernorm_qkv.value_weight"].data[:] = (
hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
)
if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:]
if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = hf_state_dict[
layer_prefix + "self_attn.o_proj.weight"
].data[:]
if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:]
if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = hf_state_dict[
layer_prefix + "post_attention_layernorm.weight"
].data[:]
if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:]
# It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to
# load them separately.
if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.intermediate_size] = \
hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data
if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.intermediate_size:] = \
hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data
if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]
return all_layer_prefixes
\ No newline at end of file
if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
: config.intermediate_size
] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data
if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
config.intermediate_size :
] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data
if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = hf_state_dict[
layer_prefix + "mlp.down_proj.weight"
].data[:]
return all_layer_prefixes
......@@ -10,29 +10,36 @@ import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup, AutoConfig
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
get_linear_schedule_with_warmup,
AutoConfig,
)
from transformers import DataCollatorForLanguageModeling
from datasets import load_dataset
from accelerate import Accelerator
from accelerate.utils.dataclasses import FP8RecipeKwargs
class HyperParameters:
def __init__(self):
self.mixed_precision = "bf16"
#self.model_name = "" # <== Add model weight location here
# self.model_name = "" # <== Add model weight location here
self.dataset_name = "timdettmers/openassistant-guanaco"
self.dataset_text_field = "text"
self.learning_rate = 1.41e-5
self.batch_size = 8
self.max_seq_length = 256
self.gradient_accumulation_steps = 1
self.num_warmup_steps=5
self.num_training_steps=10
self.num_warmup_steps = 5
self.num_training_steps = 10
hyperparams = HyperParameters()
def get_dataloaders(accelerator:Accelerator, hyperparams):
def get_dataloaders(accelerator: Accelerator, hyperparams):
dataset = load_dataset(hyperparams.dataset_name, split="train")
tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)
if getattr(tokenizer, "pad_token", None) is None:
......@@ -45,16 +52,12 @@ def get_dataloaders(accelerator:Accelerator, hyperparams):
padding=False,
max_length=hyperparams.max_seq_length,
return_overflowing_tokens=False,
return_length=False
return_length=False,
)
return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
with accelerator.main_process_first():
dataset = dataset.map(
tokenize,
batched=True,
remove_columns=dataset.column_names
)
dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
# Simply pad to the multiple of 16 for both FP8 and BF16 precision
pad_to_multiple_of = 16
......@@ -72,6 +75,7 @@ def get_dataloaders(accelerator:Accelerator, hyperparams):
train_dataloader = DataLoader(dataset, **dataloader_params)
return train_dataloader
def init_baseline_model(hyperparams):
# Init the model
config = AutoConfig.from_pretrained(hyperparams.model_name)
......@@ -84,42 +88,47 @@ def init_baseline_model(hyperparams):
)
model = model.cuda()
# Needed for the cases when using TELlamaForCausalLM. So adding here for 1:1 comparison
model.config.use_cache=False
model.config.use_cache = False
return model
def init_te_llama_model(hyperparams):
# Init the model
from te_llama import TELlamaForCausalLM
config = AutoConfig.from_pretrained(hyperparams.model_name)
config._attn_implementation = "flash_attention_2"
model = TELlamaForCausalLM.from_pretrained_local(
hyperparams.model_name,
config=config,
torch_dtype=torch.bfloat16,
hyperparams.model_name,
config=config,
torch_dtype=torch.bfloat16,
)
model = model.cuda()
# Needed for the cases when using TELlamaForCausalLM
model.config.use_cache=False
model.config.use_cache = False
return model
def wrap_with_accelerator(model, hyperparams):
# Create FP8 kwarg handler if required
fp8_kwarg_handler = [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None
fp8_kwarg_handler = (
[FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None
)
# Init HF accelerator that's used for training
accelerator = Accelerator(
log_with="wandb",
gradient_accumulation_steps=hyperparams.gradient_accumulation_steps,
mixed_precision=hyperparams.mixed_precision,
kwargs_handlers=fp8_kwarg_handler
kwargs_handlers=fp8_kwarg_handler,
)
#accelerator.print(f'State: {accelerator.state}')
# accelerator.print(f'State: {accelerator.state}')
train_dataloader = get_dataloaders(accelerator, hyperparams)
# Wrap model, optimizer/scheduler, dataloaders in accelerate
optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=True)
optimizer = AdamW(params=model.parameters(), lr=hyperparams.learning_rate, fused=True)
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=100,
......@@ -131,6 +140,7 @@ def wrap_with_accelerator(model, hyperparams):
return accelerator, model, optimizer, train_dataloader, lr_scheduler
def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler):
model.train()
total_loss = 0
......@@ -170,7 +180,11 @@ def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer,
end.record()
accelerator.end_training()
print(f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step: {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds")
print(
f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step:"
f" {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds"
)
def restart_jupyter_notebook():
# Try restarting the Jupyter kernel
......@@ -179,18 +193,23 @@ def restart_jupyter_notebook():
# Check whether the device memory has been flushed
if torch.cuda.memory_allocated() != 0:
import warnings
warnings.warn("The device memory hasn't been flushed, trying with a second method!")
# Try restarting the Jupyter kernel another way
# Restart the kernel
from IPython.core.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")
if torch.cuda.memory_allocated() != 0:
print("The device memory hasn't been flushed, try manually restarting the Jupyter kernel!")
print(
"The device memory hasn't been flushed, try manually restarting the Jupyter kernel!"
)
# Suppress the warnings
if not sys.warnoptions:
import warnings
warnings.simplefilter("ignore")
torch.set_warn_always(False)
......@@ -22,18 +22,19 @@ from jax.experimental.pjit import pjit
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
DEVICE_DP_AXIS = 'data'
DEVICE_TP_AXIS = 'model'
NAMED_BROADCAST_AXIS = 'my_broadcast_axis'
NAMED_TP_AXIS = 'my_tp_axis'
PARAMS_KEY = 'params'
PARAMS_AXES_KEY = PARAMS_KEY + '_axes'
DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
NAMED_BROADCAST_AXIS = "my_broadcast_axis"
NAMED_TP_AXIS = "my_tp_axis"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
enable_seq_paral: bool
......@@ -41,36 +42,43 @@ class Net(nn.Module):
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te_flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type='padding',
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
dtype=jnp.bfloat16)
te_Encoder = partial(
te_flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
if self.enable_seq_paral:
# Trigger all-gather to collect a complete tensor alone seqence on each device.
x = jax.lax.with_sharding_constraint(x,
jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None))
x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
dtype=jnp.bfloat16)(x)
x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
dtype=jnp.bfloat16)(x)
x = jax.lax.with_sharding_constraint(
x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
)
x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
dtype=jnp.bfloat16,
)(x)
x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
dtype=jnp.bfloat16,
)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x
......@@ -98,20 +106,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
"""Train for a single epoch."""
train_ds_size = len(train_ds['sentence'])
train_ds_size = len(train_ds["sentence"])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_inputs = train_ds['sentence'][perm, ...]
batch_masks = train_ds['mask'][perm, ...]
batch_labels = train_ds['label'][perm, ...]
state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks,
batch_labels, var_collect, rngs)
batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
state, loss, accuracy, var_collect = train_fn(
state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
......@@ -137,7 +146,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
"""Evaluation loop."""
test_ds_size = len(test_ds['sentence'])
test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
......@@ -145,9 +154,9 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
batch_inputs = test_ds['sentence'][batch_start:batch_end]
batch_masks = test_ds['mask'][batch_start:batch_end]
batch_labels = test_ds['label'][batch_start:batch_end]
batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
......@@ -159,12 +168,12 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download('punkt')
dataset_size = len(dataset['sentence'])
nltk.download("punkt")
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']):
for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
......@@ -184,9 +193,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
mask_2d[:seq_len, :seq_len] = 0
new_dataset = {
'sentence': output,
'label': dataset['label'].astype(np.float32),
'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
"sentence": output,
"label": dataset["label"].astype(np.float32),
"mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
}
return new_dataset, vocab, word_id
......@@ -196,12 +205,12 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
train_ds = load_dataset('glue', 'cola', split='train')
train_ds.set_format(type='np')
train_ds = load_dataset("glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset('glue', 'cola', split='validation')
test_ds.set_format(type='np')
test_ds = load_dataset("glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......@@ -210,7 +219,8 @@ def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
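The assertion above works because jax.make_jaxpr traces the step function and the string form of the jaxpr lists every primitive, so an FP8-enabled model surfaces primitives whose names contain "fp8_". The same substring technique on an unrelated toy function:

import jax
import jax.numpy as jnp

def toy_fn(x):
    return jnp.sin(x) * 2.0

jaxpr = jax.make_jaxpr(toy_fn)(jnp.ones((4,)))
assert "sin" in str(jaxpr)        # primitive names show up in the jaxpr text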
def get_params_pspec(sharding_rules, abs_var_collect):
......@@ -255,8 +265,9 @@ def train_and_evaluate(args):
num_gpu_tp = 1
assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
assert args.test_batch_size % num_gpu_dp == 0, \
f"Test batch size needs to be multiple of {num_gpu_dp}"
assert (
args.test_batch_size % num_gpu_dp == 0
), f"Test batch size needs to be multiple of {num_gpu_dp}"
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
......@@ -270,9 +281,9 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(args.use_fp8,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None,
None)):
with te.fp8_autocast(
args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
):
encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
......@@ -285,18 +296,21 @@ def train_and_evaluate(args):
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec)
out_shardings = {key: params_pspec if key is PARAMS_KEY else None \
for key in abs_var_collect}
out_shardings = {
key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect
}
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(apply_fn=encoder.apply,
params=params,
tx=optimizer)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
state_pspec = get_state_pspec(state, params_pspec)
labels_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS,)
labels_pspec = jax.sharding.PartitionSpec(
DEVICE_DP_AXIS,
)
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None)
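The sharding tuples above mirror the positional signature of each pjit-compiled function; as a hedged, self-contained illustration of the same pjit + PartitionSpec pattern on a toy function (axis name and shapes are illustrative only):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit

P = jax.sharding.PartitionSpec
n_dev = jax.local_device_count()
mesh = jax.sharding.Mesh(mesh_utils.create_device_mesh((n_dev,)), ("data",))

def scale(x):
    return x * 2.0

with mesh:
    # Shard the leading (batch) dimension over "data"; None entries mean replicated.
    pjit_scale = pjit(scale, P("data", None), None)
    y = pjit_scale(jnp.ones((2 * n_dev, 4)))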
......@@ -323,16 +337,20 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step)
state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step
)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
var_collect, pjit_eval_step)
test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, pjit_eval_step
)
print(f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} ")
print(
f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} "
)
return [train_loss, train_accuracy, test_loss, test_accuracy]
......@@ -382,14 +400,15 @@ def encoder_parser(args):
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration")
parser.add_argument("--enable-sp",
action="store_true",
default=False,
help="Enable sequence parallelism.")
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
)
return parser.parse_args(args)
......
......@@ -22,32 +22,35 @@ from jax.experimental.pjit import pjit
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
DEVICE_DP_AXIS = 'data'
PARAMS_KEY = 'params'
PARAMS_AXES_KEY = PARAMS_KEY + '_axes'
DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te_flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type='padding',
enable_relative_embedding=False,
dtype=jnp.bfloat16)
te_Encoder = partial(
te_flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
......@@ -82,20 +85,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
"""Train for a single epoch."""
train_ds_size = len(train_ds['sentence'])
train_ds_size = len(train_ds["sentence"])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_inputs = train_ds['sentence'][perm, ...]
batch_masks = train_ds['mask'][perm, ...]
batch_labels = train_ds['label'][perm, ...]
state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks,
batch_labels, var_collect, rngs)
batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
state, loss, accuracy, var_collect = train_fn(
state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
......@@ -121,7 +125,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
"""Evaluation loop."""
test_ds_size = len(test_ds['sentence'])
test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
......@@ -129,9 +133,9 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
batch_inputs = test_ds['sentence'][batch_start:batch_end]
batch_masks = test_ds['mask'][batch_start:batch_end]
batch_labels = test_ds['label'][batch_start:batch_end]
batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
......@@ -143,12 +147,12 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download('punkt')
dataset_size = len(dataset['sentence'])
nltk.download("punkt")
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']):
for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
......@@ -168,9 +172,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
mask_2d[:seq_len, :seq_len] = 0
new_dataset = {
'sentence': output,
'label': dataset['label'].astype(np.float32),
'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
"sentence": output,
"label": dataset["label"].astype(np.float32),
"mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
}
return new_dataset, vocab, word_id
......@@ -180,12 +184,12 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
train_ds = load_dataset('glue', 'cola', split='train')
train_ds.set_format(type='np')
train_ds = load_dataset("glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset('glue', 'cola', split='validation')
test_ds.set_format(type='np')
test_ds = load_dataset("glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......@@ -194,7 +198,8 @@ def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
def get_params_pspec(sharding_rules, abs_var_collect):
......@@ -232,8 +237,7 @@ def train_and_evaluate(args):
num_gpu = jax.local_device_count()
assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}"
assert args.test_batch_size % num_gpu == 0, \
f"Test batch size needs to be multiple of {num_gpu}"
assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}"
device_mesh = mesh_utils.create_device_mesh((num_gpu,))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)):
......@@ -247,8 +251,9 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(args.use_fp8,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)):
with te.fp8_autocast(
args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)
):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
......@@ -260,18 +265,21 @@ def train_and_evaluate(args):
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec)
out_shardings = {key: params_pspec if key is PARAMS_KEY else None \
for key in abs_var_collect}
out_shardings = {
key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect
}
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(apply_fn=encoder.apply,
params=params,
tx=optimizer)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
state_pspec = get_state_pspec(state, params_pspec)
labels_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS,)
labels_pspec = jax.sharding.PartitionSpec(
DEVICE_DP_AXIS,
)
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None)
......@@ -298,16 +306,20 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step)
state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step
)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
var_collect, pjit_eval_step)
test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, pjit_eval_step
)
print(f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} ")
print(
f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} "
)
return [train_loss, train_accuracy, test_loss, test_accuracy]
......@@ -357,10 +369,12 @@ def encoder_parser(args):
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration")
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
)
return parser.parse_args(args)
......
......@@ -19,30 +19,33 @@ from flax.training import train_state
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
PARAMS_KEY = 'params'
DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te_flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type='padding',
enable_relative_embedding=False,
dtype=jnp.bfloat16)
te_Encoder = partial(
te_flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
......@@ -78,20 +81,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def train_epoch(state, train_ds, batch_size, rngs, var_collect):
"""Train for a single epoch."""
train_ds_size = len(train_ds['sentence'])
train_ds_size = len(train_ds["sentence"])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_inputs = train_ds['sentence'][perm, ...]
batch_masks = train_ds['mask'][perm, ...]
batch_labels = train_ds['label'][perm, ...]
state, loss, accuracy, var_collect = train_step(state, batch_inputs, batch_masks,
batch_labels, var_collect, rngs)
batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
state, loss, accuracy, var_collect = train_step(
state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
......@@ -118,7 +122,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def eval_model(state, test_ds, batch_size, var_collect):
"""Evaluation loop."""
test_ds_size = len(test_ds['sentence'])
test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
......@@ -126,9 +130,9 @@ def eval_model(state, test_ds, batch_size, var_collect):
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
batch_inputs = test_ds['sentence'][batch_start:batch_end]
batch_masks = test_ds['mask'][batch_start:batch_end]
batch_labels = test_ds['label'][batch_start:batch_end]
batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
......@@ -140,12 +144,12 @@ def eval_model(state, test_ds, batch_size, var_collect):
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download('punkt')
dataset_size = len(dataset['sentence'])
nltk.download("punkt")
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']):
for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
......@@ -165,9 +169,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
mask_2d[:seq_len, :seq_len] = 0
new_dataset = {
'sentence': output,
'label': dataset['label'].astype(np.float32),
'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
"sentence": output,
"label": dataset["label"].astype(np.float32),
"mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
}
return new_dataset, vocab, word_id
......@@ -177,12 +181,12 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
train_ds = load_dataset('glue', 'cola', split='train')
train_ds.set_format(type='np')
train_ds = load_dataset("glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset('glue', 'cola', split='validation')
test_ds.set_format(type='np')
test_ds = load_dataset("glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......@@ -191,7 +195,8 @@ def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
def train_and_evaluate(args):
......@@ -214,9 +219,9 @@ def train_and_evaluate(args):
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
var_collect = encoder.init(init_rngs, inputs, masks)
tx = optax.adamw(args.lr)
state = train_state.TrainState.create(apply_fn=encoder.apply,
params=var_collect[PARAMS_KEY],
tx=tx)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=var_collect[PARAMS_KEY], tx=tx
)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
......@@ -235,15 +240,18 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect)
state, train_ds, args.batch_size, rngs, var_collect
)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
print(f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} ")
print(
f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} "
)
return [train_loss, train_accuracy, test_loss, test_accuracy]
......@@ -293,10 +301,12 @@ def encoder_parser(args):
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration")
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
)
return parser.parse_args(args)
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
""" MNIST training on single GPU"""
"""MNIST training on single GPU"""
import argparse
import unittest
from functools import partial
......@@ -20,13 +20,14 @@ import transformer_engine.jax.flax as te_flax
IMAGE_H = 28
IMAGE_W = 28
IMAGE_C = 1
PARAMS_KEY = 'params'
DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
class Net(nn.Module):
"""CNN model for MNIST."""
use_te: bool = False
@nn.compact
......@@ -83,17 +84,17 @@ def update_model(state, grads):
def train_epoch(state, train_ds, batch_size, rngs, var_collect):
"""Train for a single epoch."""
train_ds_size = len(train_ds['image'])
train_ds_size = len(train_ds["image"])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_images = train_ds['image'][perm, ...]
batch_labels = train_ds['label'][perm, ...]
batch_images = train_ds["image"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
grads, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect, rngs)
state, var_collect = update_model(state, grads)
epoch_loss.append(loss)
......@@ -106,7 +107,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect):
def eval_model(state, test_ds, batch_size, var_collect):
"""Evaluation loop."""
test_ds_size = len(test_ds['image'])
test_ds_size = len(test_ds["image"])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
......@@ -114,8 +115,8 @@ def eval_model(state, test_ds, batch_size, var_collect):
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
batch_images = test_ds['image'][batch_start:batch_end]
batch_labels = test_ds['label'][batch_start:batch_end]
batch_images = test_ds["image"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
_, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
......@@ -127,21 +128,21 @@ def eval_model(state, test_ds, batch_size, var_collect):
def get_datasets():
"""Load MNIST train and test datasets into memory."""
train_ds = load_dataset('mnist', split='train')
train_ds.set_format(type='np')
batch_size = train_ds['image'].shape[0]
train_ds = load_dataset("mnist", split="train")
train_ds.set_format(type="np")
batch_size = train_ds["image"].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
new_train_ds = {
'image': train_ds['image'].astype(np.float32).reshape(shape) / 255.,
'label': train_ds['label']
"image": train_ds["image"].astype(np.float32).reshape(shape) / 255.0,
"label": train_ds["label"],
}
test_ds = load_dataset('mnist', split='test')
test_ds.set_format(type='np')
batch_size = test_ds['image'].shape[0]
test_ds = load_dataset("mnist", split="test")
test_ds.set_format(type="np")
batch_size = test_ds["image"].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
new_test_ds = {
'image': test_ds['image'].astype(np.float32).reshape(shape) / 255.,
'label': test_ds['label']
"image": test_ds["image"].astype(np.float32).reshape(shape) / 255.0,
"label": test_ds["label"],
}
return new_train_ds, new_test_ds
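The reshape above just adds a channel axis and scales pixels to [0, 1]; a minimal sketch with a fake batch standing in for the HuggingFace arrays:

import numpy as np

raw = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)   # fake MNIST-style images
images = raw.astype(np.float32).reshape((4, 28, 28, 1)) / 255.0     # NHWC, values in [0, 1]
assert images.shape == (4, 28, 28, 1) and images.max() <= 1.0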
......@@ -149,8 +150,13 @@ def get_datasets():
def check_fp8(state, var_collect, input_shape, label_shape):
"Check if model includes FP8."
assert "f8_" in str(
jax.make_jaxpr(apply_model)(state, jnp.empty(input_shape, dtype=jnp.bfloat16),
jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect))
jax.make_jaxpr(apply_model)(
state,
jnp.empty(input_shape, dtype=jnp.bfloat16),
jnp.empty(label_shape, dtype=jnp.bfloat16),
var_collect,
)
)
def train_and_evaluate(args):
......@@ -173,17 +179,21 @@ def train_and_evaluate(args):
cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
tx = optax.sgd(args.lr, args.momentum)
state = train_state.TrainState.create(apply_fn=cnn.apply,
params=var_collect[PARAMS_KEY],
tx=tx)
state = train_state.TrainState.create(
apply_fn=cnn.apply, params=var_collect[PARAMS_KEY], tx=tx
)
if args.use_fp8:
check_fp8(state, var_collect, input_shape, label_shape)
if args.dry_run:
apply_model(state, jnp.empty(input_shape, dtype=jnp.bfloat16),
jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect,
{DROPOUT_KEY: dropout_rng})
apply_model(
state,
jnp.empty(input_shape, dtype=jnp.bfloat16),
jnp.empty(label_shape, dtype=jnp.bfloat16),
var_collect,
{DROPOUT_KEY: dropout_rng},
)
print("PASSED")
return None
......@@ -193,14 +203,17 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect)
state, train_ds, args.batch_size, rngs, var_collect
)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
print(f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} ")
print(
f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} "
)
return [train_loss, train_accuracy, test_loss, test_accuracy]
......@@ -250,15 +263,18 @@ def mnist_parser(args):
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration. " \
"It also enables Transformer Engine implicitly.")
parser.add_argument("--use-te",
action="store_true",
default=False,
help="Use Transformer Engine")
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help=(
"Use FP8 for inference and training without recalibration. "
"It also enables Transformer Engine implicitly."
),
)
parser.add_argument(
"--use-te", action="store_true", default=False, help="Use Transformer Engine"
)
return parser.parse_args(args)
......
......@@ -59,7 +59,9 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8):
model.train()
losses = []
for batch_id, (data, labels) in enumerate(train_loader):
with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager
with paddle.amp.auto_cast(
dtype="bfloat16", level="O2"
): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=use_fp8):
outputs = model(data)
loss = F.cross_entropy(outputs, labels)
......@@ -70,10 +72,12 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8):
optimizer.clear_gradients()
if batch_id % args.log_interval == 0:
print(f"Train Epoch: {epoch} "
f"[{batch_id * len(data)}/{len(train_loader.dataset)} "
f"({100. * batch_id / len(train_loader):.0f}%)]\t"
f"Loss: {loss.item():.6f}")
print(
f"Train Epoch: {epoch} "
f"[{batch_id * len(data)}/{len(train_loader.dataset)} "
f"({100. * batch_id / len(train_loader):.0f}%)]\t"
f"Loss: {loss.item():.6f}"
)
if args.dry_run:
return loss.item()
avg_loss = sum(losses) / len(losses)
......@@ -89,7 +93,9 @@ def evaluate(model, test_loader, epoch, use_fp8):
with paddle.no_grad():
for data, labels in test_loader:
with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager
with paddle.amp.auto_cast(
dtype="bfloat16", level="O2"
): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=use_fp8):
outputs = model(data)
acc = metric.compute(outputs, labels)
......@@ -104,7 +110,9 @@ def calibrate(model, test_loader):
with paddle.no_grad():
for data, _ in test_loader:
with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager
with paddle.amp.auto_cast(
dtype="bfloat16", level="O2"
): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=False, calibrating=True):
_ = model(data)
......@@ -160,20 +168,27 @@ def mnist_parser(args):
metavar="N",
help="how many batches to wait before logging training status",
)
parser.add_argument("--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration. " \
"It also enables Transformer Engine implicitly.")
parser.add_argument("--use-fp8-infer",
action="store_true",
default=False,
help="Use FP8 for inference only. If not using FP8 for training, "
"calibration is performed for FP8 infernece.")
parser.add_argument("--use-te",
action="store_true",
default=False,
help="Use Transformer Engine")
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help=(
"Use FP8 for inference and training without recalibration. "
"It also enables Transformer Engine implicitly."
),
)
parser.add_argument(
"--use-fp8-infer",
action="store_true",
default=False,
help=(
"Use FP8 for inference only. If not using FP8 for training, "
"calibration is performed for FP8 infernece."
),
)
parser.add_argument(
"--use-te", action="store_true", default=False, help="Use Transformer Engine"
)
args = parser.parse_args(args)
return args
......@@ -185,9 +200,9 @@ def train_and_evaluate(args):
paddle.seed(args.seed)
# Load MNIST dataset
transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
train_dataset = MNIST(mode='train', transform=transform)
val_dataset = MNIST(mode='test', transform=transform)
transform = Normalize(mean=[127.5], std=[127.5], data_format="CHW")
train_dataset = MNIST(mode="train", transform=transform)
val_dataset = MNIST(mode="test", transform=transform)
# Define data loaders
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
......@@ -198,7 +213,7 @@ def train_and_evaluate(args):
optimizer = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
# Cast model to BF16
model = paddle.amp.decorate(models=model, level='O2', dtype='bfloat16')
model = paddle.amp.decorate(models=model, level="O2", dtype="bfloat16")
for epoch in range(1, args.epochs + 1):
loss = train(args, model, train_loader, optimizer, epoch, args.use_fp8)
......@@ -209,7 +224,7 @@ def train_and_evaluate(args):
if args.save_model or args.use_fp8_infer:
paddle.save(model.state_dict(), "mnist_cnn.pdparams")
print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8))
print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8))
weights = paddle.load("mnist_cnn.pdparams")
model.set_state_dict(weights)
acc = evaluate(model, val_loader, 0, args.use_fp8)
......@@ -235,8 +250,10 @@ class TestMNIST(unittest.TestCase):
assert actual[0] < desired_traing_loss
assert actual[1] > desired_test_accuracy
@unittest.skipIf(paddle.device.cuda.get_device_capability() < (8, 0),
"BF16 MNIST example requires Ampere+ GPU")
@unittest.skipIf(
paddle.device.cuda.get_device_capability() < (8, 0),
"BF16 MNIST example requires Ampere+ GPU",
)
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
self.args.use_te = True
......
......@@ -15,62 +15,77 @@ import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
def parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(
description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers.")
parser.add_argument('-i', "--num-iters", type=int, default=5,
help="Number of dummy 'training' iterations.")
parser.add_argument('-b', "--batch-size", type=int, default=2,
help="Input batch size.")
parser.add_argument('-s', "--seq-length", type=int, default=2048,
help="Input sequence length.")
parser.add_argument('-n', "--num-heads", type=int, default=64,
help="Number of attention heads.")
parser.add_argument('-d', "--head-dim", type=int, default=128,
help="Dimension of each attention head.")
parser.add_argument("--mlp-expansion-factor", type=int, default=4,
help="MLP block intermediate size as a factor of hidden dimension.")
parser.add_argument("--seed", type=int, default=1234,
help="RNG seed.")
parser.add_argument("--fp8", action="store_true", default=False,
help="Enables the te.fp8_autocast() context.")
parser.add_argument("--no-comm-overlap", action="store_true", default=False,
help="Disable the comm+GEMM overlap.")
parser.add_argument('-v', "--verbose", action="store_true", default=False)
description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers."
)
parser.add_argument(
"-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
)
parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.")
parser.add_argument(
"-n", "--num-heads", type=int, default=64, help="Number of attention heads."
)
parser.add_argument(
"-d", "--head-dim", type=int, default=128, help="Dimension of each attention head."
)
parser.add_argument(
"--mlp-expansion-factor",
type=int,
default=4,
help="MLP block intermediate size as a factor of hidden dimension.",
)
parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
)
parser.add_argument(
"--no-comm-overlap",
action="store_true",
default=False,
help="Disable the comm+GEMM overlap.",
)
parser.add_argument("-v", "--verbose", action="store_true", default=False)
return parser.parse_args(argv, namespace)
def train(opts):
WORLD_RANK = int(os.getenv("RANK"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE"))
def dist_print(msg, end='\n', all_ranks=False):
def dist_print(msg, end="\n", all_ranks=False):
if WORLD_RANK == 0 or all_ranks:
print(f"[RANK-{WORLD_RANK}] {msg}", end=end)
# Seed RNG
torch.cuda.set_device(WORLD_RANK)
torch.manual_seed(opts.seed+WORLD_RANK)
torch.cuda.manual_seed(opts.seed+WORLD_RANK)
torch.manual_seed(opts.seed + WORLD_RANK)
torch.cuda.manual_seed(opts.seed + WORLD_RANK)
# Initialize torch.distributed global process group and get TP group
dist.init_process_group(backend="nccl",
rank=WORLD_RANK,
world_size=WORLD_SIZE,
device_id=torch.device(f'cuda:{WORLD_RANK}'))
dist.init_process_group(
backend="nccl",
rank=WORLD_RANK,
world_size=WORLD_SIZE,
device_id=torch.device(f"cuda:{WORLD_RANK}"),
)
tp_group = dist.new_group(backend="nccl")
tp_size = dist.get_world_size(tp_group)
# Initialize userbuffers
ag_cfg = { # Ring-exchange All-Gather overlap for fc1_fprop and fc2_dgrad
'method': 'ring_exchange',
'num_splits' : 8,
'num_sm' : 1,
'set_sm_margin' : False,
"method": "ring_exchange",
"num_splits": 8,
"num_sm": 1,
"set_sm_margin": False,
}
rs_cfg = { # Reduce-scatter overlap for fc1_dgrad and fc2_fprop
'method': 'ring_exchange',
'num_splits' : 4,
'num_sm' : 1,
'set_sm_margin' : True,
"method": "ring_exchange",
"num_splits": 4,
"num_sm": 1,
"set_sm_margin": True,
}
hidden_size = opts.num_heads * opts.head_dim
batched_size = opts.seq_length * opts.batch_size
......@@ -78,30 +93,31 @@ def train(opts):
te.initialize_ub(
[batched_size, hidden_size],
tp_group,
use_fp8 = opts.fp8,
dtype = torch.bfloat16,
ub_cfgs = {
'fc1_fprop': ag_cfg,
'fc1_dgrad': rs_cfg,
'fc2_fprop': rs_cfg,
'fc2_dgrad': ag_cfg,
use_fp8=opts.fp8,
dtype=torch.bfloat16,
ub_cfgs={
"fc1_fprop": ag_cfg,
"fc1_dgrad": rs_cfg,
"fc2_fprop": rs_cfg,
"fc2_dgrad": ag_cfg,
},
)
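For a rough sense of the sizes registered above with the script's defaults (the tp_size of 8 below is an assumption; at runtime it equals the size of the NCCL tensor-parallel group):

seq_length, batch_size = 2048, 2              # argparse defaults above
num_heads, head_dim, tp_size = 64, 128, 8     # tp_size assumed

hidden_size = num_heads * head_dim            # 8192
batched_size = seq_length * batch_size        # 4096 rows in the registered [batched_size, hidden_size] buffer
local_seq = seq_length // tp_size             # 256 sequence rows held per rank with sequence parallelism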
#
model = te.LayerNormMLP(
hidden_size, opts.mlp_expansion_factor * hidden_size,
params_dtype = torch.bfloat16,
device = 'cuda',
tp_group = tp_group,
tp_size = tp_size,
set_parallel_mode = True,
sequence_parallel = True, # this is required for comm+GEMM overlap
seq_length = opts.seq_length,
micro_batch_size = opts.batch_size,
ub_overlap_rs_dgrad = not opts.no_comm_overlap,
ub_overlap_rs = not opts.no_comm_overlap,
ub_overlap_ag = not opts.no_comm_overlap,
hidden_size,
opts.mlp_expansion_factor * hidden_size,
params_dtype=torch.bfloat16,
device="cuda",
tp_group=tp_group,
tp_size=tp_size,
set_parallel_mode=True,
sequence_parallel=True, # this is required for comm+GEMM overlap
seq_length=opts.seq_length,
micro_batch_size=opts.batch_size,
ub_overlap_rs_dgrad=not opts.no_comm_overlap,
ub_overlap_rs=not opts.no_comm_overlap,
ub_overlap_ag=not opts.no_comm_overlap,
)
# Initialize optimizer with model parameters
......@@ -109,16 +125,19 @@ def train(opts):
# Fp8 recipe setup
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32,
amax_compute_algo="max")
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
# Start dummy "training" iterations
for i in range(opts.num_iters):
dist_print(f"Iter {i+1}", all_ranks=opts.verbose)
dist_print("|-- Generate random input batch", all_ranks=opts.verbose)
x = torch.rand((opts.seq_length // tp_size, opts.batch_size, hidden_size),
dtype=torch.bfloat16, device='cuda', requires_grad=True)
x = torch.rand(
(opts.seq_length // tp_size, opts.batch_size, hidden_size),
dtype=torch.bfloat16,
device="cuda",
requires_grad=True,
)
dist_print("|-- Forward pass", all_ranks=opts.verbose)
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group):
......@@ -135,17 +154,15 @@ def train(opts):
te.destroy_ub()
dist.destroy_process_group()
if __name__ == "__main__":
if "TORCHELASTIC_RUN_ID" in os.environ.keys():
args = parse_args()
train(args)
else:
subprocess.run(
[
'torchrun', f'--nproc-per-node={torch.cuda.device_count()}',
*sys.argv
],
["torchrun", f"--nproc-per-node={torch.cuda.device_count()}", *sys.argv],
env=os.environ,
check=True
check=True,
)
os._exit(0)
......@@ -14,7 +14,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper
checkpoint_wrapper,
)
import transformer_engine.pytorch as te
......@@ -29,46 +29,56 @@ rng_seed = 1234
torch.manual_seed(rng_seed)
torch.cuda.manual_seed(rng_seed)
CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker()
CUDA_RNG_STATES_TRACKER.add('model-parallel-rng', rng_seed)
CUDA_RNG_STATES_TRACKER.add("model-parallel-rng", rng_seed)
def get_cuda_rng_tracker():
return CUDA_RNG_STATES_TRACKER
def apply_fsdp_checkpointing(model, blocks):
"""apply activation checkpointing to model
returns None as model is updated directly
"""
wrapper = lambda m: checkpoint_wrapper(m,
checkpoint_fn=te.distributed.checkpoint,
use_reentrant=False,
get_rng_state_tracker=get_cuda_rng_tracker)
wrapper = lambda m: checkpoint_wrapper(
m,
checkpoint_fn=te.distributed.checkpoint,
use_reentrant=False,
get_rng_state_tracker=get_cuda_rng_tracker,
)
check_fn = lambda submodule: isinstance(submodule, blocks)
apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
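The helper above swaps te.distributed.checkpoint in as the recompute function; the same selective-wrapping mechanics can be sketched on a plain torch model with the default checkpointing (everything below is illustrative, not the example's configuration):

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
apply_activation_checkpointing(
    toy,
    checkpoint_wrapper_fn=checkpoint_wrapper,        # default torch.utils.checkpoint path
    check_fn=lambda m: isinstance(m, nn.Linear),     # only Linear submodules get recomputed
)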
def lowercase(s):
return str(s).lower()
def torch_dtype(d):
typemap = {
'fp32' : torch.float32,
'float32' : torch.float32,
'fp16' : torch.float16,
'float16' : torch.float16,
'bf16' : torch.bfloat16,
'bfloat16' : torch.bfloat16
"fp32": torch.float32,
"float32": torch.float32,
"fp16": torch.float16,
"float16": torch.float16,
"bf16": torch.bfloat16,
"bfloat16": torch.bfloat16,
}
if lowercase(d) not in typemap.keys():
raise TypeError
return typemap[lowercase(d)]
te_layer_map = {
'linear': te.Linear,
'layernorm': te.LayerNorm,
'rmsnorm': te.RMSNorm,
'layernormlinear': te.LayerNormLinear,
'layernormmlp': te.LayerNormMLP,
'multiheadattention': te.MultiheadAttention,
'transformerlayer': te.TransformerLayer
"linear": te.Linear,
"layernorm": te.LayerNorm,
"rmsnorm": te.RMSNorm,
"layernormlinear": te.LayerNormLinear,
"layernormmlp": te.LayerNormMLP,
"multiheadattention": te.MultiheadAttention,
"transformerlayer": te.TransformerLayer,
}
def te_layer(l):
if l is not None:
if lowercase(l) not in te_layer_map.keys():
......@@ -76,74 +86,120 @@ def te_layer(l):
return te_layer_map[lowercase(l)]
return None
def get_layer_args(opts):
hidden_size = opts.num_heads * opts.head_dim
layer_args = (hidden_size, )
layer_args = (hidden_size,)
layer_kwargs = {
'params_dtype': opts.dtype,
'device': 'cuda' if opts.no_defer_init else 'meta',
'get_rng_state_tracker': get_cuda_rng_tracker,
"params_dtype": opts.dtype,
"device": "cuda" if opts.no_defer_init else "meta",
"get_rng_state_tracker": get_cuda_rng_tracker,
}
if opts.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
ffn_hidden_size = 3 * hidden_size if opts.num_layers == 1 else hidden_size
layer_args += (ffn_hidden_size, )
layer_kwargs['bias'] = True
layer_args += (ffn_hidden_size,)
layer_kwargs["bias"] = True
if opts.layer_type == te.LayerNormMLP:
layer_kwargs['seq_length'] = opts.seq_length
layer_kwargs["seq_length"] = opts.seq_length
elif opts.layer_type == te.MultiheadAttention:
layer_args += (opts.num_heads, )
layer_kwargs['fuse_qkv_params'] = True
layer_kwargs['input_layernorm'] = True
layer_args += (opts.num_heads,)
layer_kwargs["fuse_qkv_params"] = True
layer_kwargs["input_layernorm"] = True
elif opts.layer_type == te.TransformerLayer:
layer_args += (3 * hidden_size, opts.num_heads)
layer_kwargs['fuse_qkv_params'] = True
layer_kwargs['seq_length'] = opts.seq_length
layer_kwargs["fuse_qkv_params"] = True
layer_kwargs["seq_length"] = opts.seq_length
return layer_args, layer_kwargs
def parse_fsdp_args():
parser = argparse.ArgumentParser(description="Run Transformer Engine modules with the " +
"torch.distributed.fsdp.FullyShardedDataParallel strategy.")
parser.add_argument('-v', "--verbose", action="store_true", default=False,
help="Print out information from all GPUs instead of only the root GPU-0.")
parser.add_argument('-b', "--batch-size", type=int, default=32,
help="Input batch size.")
parser.add_argument('-s', "--seq-length", type=int, default=1048,
help="Input sequence length.")
parser.add_argument('-n', "--num-heads", type=int, default=16,
help="Number of attention heads.")
parser.add_argument('-d', "--head-dim", type=int, default=128,
help="Dimension of each attention head (number of KV channels).")
parser.add_argument('-i', "--num-iters", type=int, default=5,
help="Number of dummy 'training' iterations.")
parser.add_argument('-k', "--num-layers", type=int, default=3,
help="Number of modules chained together with nn.Sequential.")
parser.add_argument("--layer-type", type=te_layer, default=te.TransformerLayer,
choices=list(te_layer_map.values()),
help="TE module type used to construct the test model.")
parser.add_argument("--seed", type=int, default=1234,
help="PyTorch RNG seed.")
parser.add_argument("--profile-memory", action="store_true",
help="Enable memory profiling via torch.profiler.profile().")
parser.add_argument("--profile-name", type=str, default=None,
help="File path for memory profiling.")
parser.add_argument("--checkpoint-layer", type=te_layer, default=None,
help="Recompute activations of the selected layer during the backward " + \
"pass instead of saving.")
parser.add_argument("--no-fp8", action="store_true", default=False,
help="Disables the te.fp8_autocast() context.")
parser.add_argument("--no-defer-init", action="store_true",
help="Defer module parameter initialization until after FSDP sharding.")
parser.add_argument("--no-te-fsdp", action="store_true",
help="Disable sharding of intermediate/activation tensors in TE modules.")
parser.add_argument("--dtype", type=torch_dtype, default=torch.bfloat16,
help="Data type for input tensor and Transformer Engine module parameters.")
parser = argparse.ArgumentParser(
description="Run Transformer Engine modules with the "
+ "torch.distributed.fsdp.FullyShardedDataParallel strategy."
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
default=False,
help="Print out information from all GPUs instead of only the root GPU-0.",
)
parser.add_argument("-b", "--batch-size", type=int, default=32, help="Input batch size.")
parser.add_argument("-s", "--seq-length", type=int, default=1048, help="Input sequence length.")
parser.add_argument(
"-n", "--num-heads", type=int, default=16, help="Number of attention heads."
)
parser.add_argument(
"-d",
"--head-dim",
type=int,
default=128,
help="Dimension of each attention head (number of KV channels).",
)
parser.add_argument(
"-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
)
parser.add_argument(
"-k",
"--num-layers",
type=int,
default=3,
help="Number of modules chained together with nn.Sequential.",
)
parser.add_argument(
"--layer-type",
type=te_layer,
default=te.TransformerLayer,
choices=list(te_layer_map.values()),
help="TE module type used to construct the test model.",
)
parser.add_argument("--seed", type=int, default=1234, help="PyTorch RNG seed.")
parser.add_argument(
"--profile-memory",
action="store_true",
help="Enable memory profiling via torch.profiler.profile().",
)
parser.add_argument(
"--profile-name", type=str, default=None, help="File path for memory profiling."
)
parser.add_argument(
"--checkpoint-layer",
type=te_layer,
default=None,
help="Recompute activations of the selected layer during the backward "
+ "pass instead of saving.",
)
parser.add_argument(
"--no-fp8",
action="store_true",
default=False,
help="Disables the te.fp8_autocast() context.",
)
parser.add_argument(
"--no-defer-init",
action="store_true",
help="Defer module parameter initialization until after FSDP sharding.",
)
parser.add_argument(
"--no-te-fsdp",
action="store_true",
help="Disable sharding of intermediate/activation tensors in TE modules.",
)
parser.add_argument(
"--dtype",
type=torch_dtype,
default=torch.bfloat16,
help="Data type for input tensor and Transformer Engine module parameters.",
)
return parser.parse_args()
def dist_print(text, all_ranks=False, no_new_line=False):
if LOCAL_RANK == 0 or all_ranks:
end = '' if no_new_line else '\n'
end = "" if no_new_line else "\n"
print(f"[GPU-{LOCAL_RANK}] " + text, end=end)
def train(opts):
# Initialize torch.distributed global process group
dist.init_process_group(backend="nccl")
......@@ -157,7 +213,7 @@ def train(opts):
te_layer_list = []
for i in range(opts.num_layers):
if opts.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
layer_kwargs['layer_number'] = i+1
layer_kwargs["layer_number"] = i + 1
te_layer_list.append(opts.layer_type(*layer_args, **layer_kwargs))
te_model = nn.Sequential(*te_layer_list)
else:
......@@ -171,20 +227,23 @@ def train(opts):
# Wrap the model with FSDP
# NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and
# controls all communication.
all_gpus = dist.new_group(backend='nccl')
all_gpus = dist.new_group(backend="nccl")
fsdp_wrap_policy = always_wrap_policy
if opts.layer_type == te.TransformerLayer:
# NOTE: FSDP causes illegal memory access without this special policy for Transformers
fsdp_wrap_policy = partial(transformer_auto_wrap_policy,
transformer_layer_cls={te.TransformerLayer})
te_model = FullyShardedDataParallel(te_model,
process_group=all_gpus,
use_orig_params=True,
mixed_precision=MixedPrecision(
param_dtype=opts.dtype,
reduce_dtype=torch.float32,
),
auto_wrap_policy=fsdp_wrap_policy)
fsdp_wrap_policy = partial(
transformer_auto_wrap_policy, transformer_layer_cls={te.TransformerLayer}
)
te_model = FullyShardedDataParallel(
te_model,
process_group=all_gpus,
use_orig_params=True,
mixed_precision=MixedPrecision(
param_dtype=opts.dtype,
reduce_dtype=torch.float32,
),
auto_wrap_policy=fsdp_wrap_policy,
)
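For contrast with the TE-specific wrap policy above, a minimal FSDP wrap of a toy module under the same MixedPrecision settings (a hedged sketch: it assumes the process was launched with torchrun on CUDA devices, and the module names are placeholders):

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy

if not dist.is_initialized():
    dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

toy = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda().to(torch.bfloat16)
toy = FullyShardedDataParallel(
    toy,
    use_orig_params=True,                  # keep original parameter objects visible to the optimizer
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
    auto_wrap_policy=always_wrap_policy,   # every submodule becomes its own FSDP unit
)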
if opts.checkpoint_layer is not None:
# Recompute the activations of the selected layer during the backward pass instead of
......@@ -218,8 +277,13 @@ def train(opts):
for i in range(opts.num_iters):
# Generate a random input batch
x = torch.rand(opts.seq_length, opts.batch_size, opts.num_heads*opts.head_dim,
dtype=opts.dtype, device='cuda')
x = torch.rand(
opts.seq_length,
opts.batch_size,
opts.num_heads * opts.head_dim,
dtype=opts.dtype,
device="cuda",
)
# fp8_autocast needs to be given the FSDP process group for amax reductions
with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
y = te_model(x)
......@@ -230,7 +294,6 @@ def train(opts):
optim.zero_grad(set_to_none=True)
del x
if opts.profile_memory:
torch.cuda.memory._dump_snapshot(f"gpu{LOCAL_RANK}_{opts.profile_name}.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
......@@ -238,7 +301,7 @@ def train(opts):
end.record()
torch.cuda.synchronize()
peak_mem = torch.cuda.max_memory_allocated()
train_time = start.elapsed_time(end)/1000.
train_time = start.elapsed_time(end) / 1000.0
dist_print(f"Training Time: {train_time}s")
dist_print(f"Avg. Iter. Time: {train_time / opts.num_iters}s")
dist_print(f"Peak Memory Use: {peak_mem * 1e-6}MBs")
......