Unverified Commit 9416519d authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Apply formatting (#929)



* Apply formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
...@@ -79,6 +79,7 @@ def calibrate(model, device, test_loader, fp8): ...@@ -79,6 +79,7 @@ def calibrate(model, device, test_loader, fp8):
with te.fp8_autocast(enabled=fp8, calibrating=True): with te.fp8_autocast(enabled=fp8, calibrating=True):
output = model(data) output = model(data)
def test(model, device, test_loader, use_fp8): def test(model, device, test_loader, use_fp8):
"""Testing function.""" """Testing function."""
model.eval() model.eval()
...@@ -89,12 +90,8 @@ def test(model, device, test_loader, use_fp8): ...@@ -89,12 +90,8 @@ def test(model, device, test_loader, use_fp8):
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8): with te.fp8_autocast(enabled=use_fp8):
output = model(data) output = model(data)
test_loss += F.nll_loss( test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
output, target, reduction="sum" pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
).item() # sum up batch loss
pred = output.argmax(
dim=1, keepdim=True
) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item() correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset) test_loss /= len(test_loader.dataset)
...@@ -150,9 +147,7 @@ def main(): ...@@ -150,9 +147,7 @@ def main():
default=False, default=False,
help="quickly check a single pass", help="quickly check a single pass",
) )
parser.add_argument( parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
)
parser.add_argument( parser.add_argument(
"--log-interval", "--log-interval",
type=int, type=int,
...@@ -167,7 +162,10 @@ def main(): ...@@ -167,7 +162,10 @@ def main():
help="For Saving the current Model", help="For Saving the current Model",
) )
parser.add_argument( parser.add_argument(
"--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration" "--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
) )
parser.add_argument( parser.add_argument(
"--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only" "--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
...@@ -215,7 +213,7 @@ def main(): ...@@ -215,7 +213,7 @@ def main():
if args.save_model or args.use_fp8_infer: if args.save_model or args.use_fp8_infer:
torch.save(model.state_dict(), "mnist_cnn.pt") torch.save(model.state_dict(), "mnist_cnn.pt")
print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer)) print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer))
weights = torch.load("mnist_cnn.pt") weights = torch.load("mnist_cnn.pt")
model.load_state_dict(weights) model.load_state_dict(weights)
test(model, device, test_loader, args.use_fp8_infer) test(model, device, test_loader, args.use_fp8_infer)
......
...@@ -18,21 +18,26 @@ path = sys.argv[1] ...@@ -18,21 +18,26 @@ path = sys.argv[1]
config_path = os.path.dirname(os.path.realpath(__file__)) + "/config.json" config_path = os.path.dirname(os.path.realpath(__file__)) + "/config.json"
class bcolors: class bcolors:
OKGREEN = '\033[92m' OKGREEN = "\033[92m"
WARNING = '\033[93m' WARNING = "\033[93m"
FAIL = '\033[91m' FAIL = "\033[91m"
ENDC = '\033[0m' ENDC = "\033[0m"
def print_ok(msg): def print_ok(msg):
print(f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}") print(f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}")
def print_fail(msg): def print_fail(msg):
print(f"{bcolors.FAIL}{msg}{bcolors.ENDC}") print(f"{bcolors.FAIL}{msg}{bcolors.ENDC}")
def print_warn(msg): def print_warn(msg):
print(f"{bcolors.WARNING}{msg}{bcolors.ENDC}") print(f"{bcolors.WARNING}{msg}{bcolors.ENDC}")
with open(config_path, "r") as f: with open(config_path, "r") as f:
c = json.load(f) c = json.load(f)
current_year = datetime.date.today().year current_year = datetime.date.today().year
...@@ -41,7 +46,7 @@ with open(config_path, "r") as f: ...@@ -41,7 +46,7 @@ with open(config_path, "r") as f:
else: else:
year_string = str(c["initial_year"]) + "-" + str(current_year) year_string = str(c["initial_year"]) + "-" + str(current_year)
copyright_string = c["copyright"].replace("<YEAR>", year_string) copyright_string = c["copyright"].replace("<YEAR>", year_string)
license = c["license"].split('\n') license = c["license"].split("\n")
excludes = c["exclude"] excludes = c["exclude"]
root_path = os.path.abspath(path) root_path = os.path.abspath(path)
copyright_only = c["copyright_only"] copyright_only = c["copyright_only"]
...@@ -49,21 +54,25 @@ with open(config_path, "r") as f: ...@@ -49,21 +54,25 @@ with open(config_path, "r") as f:
has_gitignore = os.path.exists(root_path + "/.gitignore") has_gitignore = os.path.exists(root_path + "/.gitignore")
def strip_star_slash(s): def strip_star_slash(s):
ret = s ret = s
if ret.startswith('*'): if ret.startswith("*"):
ret = ret[1:] ret = ret[1:]
if ret.endswith('/'): if ret.endswith("/"):
ret = ret[:-1] ret = ret[:-1]
return ret return ret
if has_gitignore: if has_gitignore:
with open(root_path + "/.gitignore", "r") as f: with open(root_path + "/.gitignore", "r") as f:
for line in f.readlines(): for line in f.readlines():
excludes.append(strip_star_slash(line.strip())) excludes.append(strip_star_slash(line.strip()))
def get_file_type(path): def get_file_type(path):
ext = {"c": ["c", "cpp", "cu", "h", "cuh"], ext = {
"c": ["c", "cpp", "cu", "h", "cuh"],
"py": ["py"], "py": ["py"],
"rst": ["rst"], "rst": ["rst"],
"txt": ["txt"], "txt": ["txt"],
...@@ -77,8 +86,10 @@ def get_file_type(path): ...@@ -77,8 +86,10 @@ def get_file_type(path):
return filetype return filetype
return "unknown" return "unknown"
success = True success = True
def check_file(path): def check_file(path):
global success global success
N = 10 N = 10
...@@ -127,9 +138,10 @@ def check_file(path): ...@@ -127,9 +138,10 @@ def check_file(path):
if copyright_found and license_found: if copyright_found and license_found:
print_ok("OK") print_ok("OK")
for root, dirs, files in os.walk(root_path): for root, dirs, files in os.walk(root_path):
print(f"Entering {root}") print(f"Entering {root}")
hidden = [d for d in dirs if d.startswith('.')] + [f for f in files if f.startswith('.')] hidden = [d for d in dirs if d.startswith(".")] + [f for f in files if f.startswith(".")]
all_excludes = excludes + hidden all_excludes = excludes + hidden
to_remove = [] to_remove = []
for d in dirs: for d in dirs:
......
...@@ -27,12 +27,13 @@ current_file_path = Path(__file__).parent.resolve() ...@@ -27,12 +27,13 @@ current_file_path = Path(__file__).parent.resolve()
from setuptools.command.build_ext import build_ext as BuildExtension from setuptools.command.build_ext import build_ext as BuildExtension
if "pytorch" in frameworks: if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks: elif "paddle" in frameworks:
from paddle.utils.cpp_extension import BuildExtension from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks: elif "jax" in frameworks:
install_and_import('pybind11') install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension from pybind11.setup_helpers import build_ext as BuildExtension
...@@ -86,34 +87,45 @@ if __name__ == "__main__": ...@@ -86,34 +87,45 @@ if __name__ == "__main__":
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks: if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension from build_tools.pytorch import setup_pytorch_extension
ext_modules.append( ext_modules.append(
setup_pytorch_extension( setup_pytorch_extension(
"transformer_engine/pytorch/csrc", "transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc", current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine")) current_file_path / "transformer_engine",
)
)
if "jax" in frameworks: if "jax" in frameworks:
from build_tools.jax import setup_jax_extension from build_tools.jax import setup_jax_extension
ext_modules.append( ext_modules.append(
setup_jax_extension( setup_jax_extension(
"transformer_engine/jax/csrc", "transformer_engine/jax/csrc",
current_file_path / "transformer_engine" / "jax" / "csrc", current_file_path / "transformer_engine" / "jax" / "csrc",
current_file_path / "transformer_engine")) current_file_path / "transformer_engine",
)
)
if "paddle" in frameworks: if "paddle" in frameworks:
from build_tools.paddle import setup_paddle_extension from build_tools.paddle import setup_paddle_extension
ext_modules.append( ext_modules.append(
setup_paddle_extension( setup_paddle_extension(
"transformer_engine/paddle/csrc", "transformer_engine/paddle/csrc",
current_file_path / "transformer_engine" / "paddle" / "csrc", current_file_path / "transformer_engine" / "paddle" / "csrc",
current_file_path / "transformer_engine")) current_file_path / "transformer_engine",
)
)
# Configure package # Configure package
setuptools.setup( setuptools.setup(
name="transformer_engine", name="transformer_engine",
version=__version__, version=__version__,
packages=setuptools.find_packages( packages=setuptools.find_packages(
include=["transformer_engine", include=[
"transformer_engine",
"transformer_engine.*", "transformer_engine.*",
"transformer_engine/build_tools"], "transformer_engine/build_tools",
],
), ),
extras_require={ extras_require={
"test": test_requires, "test": test_requires,
...@@ -125,5 +137,5 @@ if __name__ == "__main__": ...@@ -125,5 +137,5 @@ if __name__ == "__main__":
install_requires=install_requires, install_requires=install_requires,
license_files=("LICENSE",), license_files=("LICENSE",),
include_package_data=True, include_package_data=True,
package_data={"": ["VERSION.txt"]} package_data={"": ["VERSION.txt"]},
) )
...@@ -6,7 +6,7 @@ import jax ...@@ -6,7 +6,7 @@ import jax
import pytest import pytest
@pytest.fixture(autouse=True, scope='function') @pytest.fixture(autouse=True, scope="function")
def clear_live_arrays(): def clear_live_arrays():
""" """
Clear all live arrays to keep the resource clean Clear all live arrays to keep the resource clean
......
...@@ -16,15 +16,15 @@ from utils import assert_allclose, is_devices_enough ...@@ -16,15 +16,15 @@ from utils import assert_allclose, is_devices_enough
def generate_configs(): def generate_configs():
configs = [] configs = []
if is_devices_enough(2): if is_devices_enough(2):
configs.append([2, (2,), ('dp'), MeshResource(dp_resource='dp')]) configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
configs.append([2, (2,), ('tp'), MeshResource(tp_resource='tp')]) configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])
if is_devices_enough(4): if is_devices_enough(4):
TP_size = 2 TP_size = 2
DP_size = 2 DP_size = 2
configs.append( configs.append(
[4, (DP_size, TP_size), ('dp', 'tp'), [4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
MeshResource(dp_resource='dp', tp_resource='tp')]) )
return configs return configs
...@@ -46,7 +46,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref): ...@@ -46,7 +46,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
bytes_count = 0 bytes_count = 0
def get_bytes_per_txt(t): def get_bytes_per_txt(t):
''' """
The pattern of t would be like: The pattern of t would be like:
'f32[]', 'f32[]',
'(f32[1024]{0}', '(f32[1024]{0}',
...@@ -54,22 +54,22 @@ def assert_equal_collectives(target_hlo, coll_count_ref): ...@@ -54,22 +54,22 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
'f8E4M3FN[1024]{0}', 'f8E4M3FN[1024]{0}',
'i32[1024]{0}', 'i32[1024]{0}',
'bf16[1024,1024]{0}' 'bf16[1024,1024]{0}'
''' """
match = re.search(r'(i|f)(\d+).*\[([0-9,]*)\]', t) match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
_, bits_of_type, shape = match.groups() _, bits_of_type, shape = match.groups()
bytes_of_type = int(bits_of_type) // 8 bytes_of_type = int(bits_of_type) // 8
if shape == '': if shape == "":
num_of_elements = 1 num_of_elements = 1
else: else:
num_of_elements = reduce(operator.mul, map(int, shape.split(','))) num_of_elements = reduce(operator.mul, map(int, shape.split(",")))
return bytes_of_type * num_of_elements return bytes_of_type * num_of_elements
# ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...] # ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...]
if '(' in hlo_text[2]: if "(" in hlo_text[2]:
for txt in hlo_text[2:]: for txt in hlo_text[2:]:
bytes_count += get_bytes_per_txt(txt) bytes_count += get_bytes_per_txt(txt)
if ')' in txt: if ")" in txt:
break break
else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...] else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
bytes_count = get_bytes_per_txt(hlo_text[2]) bytes_count = get_bytes_per_txt(hlo_text[2])
...@@ -91,11 +91,13 @@ def assert_equal_collectives(target_hlo, coll_count_ref): ...@@ -91,11 +91,13 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
return result return result
target_result = count_collectives(target_splitted_hlo) target_result = count_collectives(target_splitted_hlo)
assert target_result == coll_count_ref, \ assert (
f"Expected collective count is {coll_count_ref}, but got {target_result}." target_result == coll_count_ref
), f"Expected collective count is {coll_count_ref}, but got {target_result}."
def compare_ops(target_func, def compare_ops(
target_func,
ref_func, ref_func,
inputs, inputs,
coll_count_ref, coll_count_ref,
...@@ -105,7 +107,8 @@ def compare_ops(target_func, ...@@ -105,7 +107,8 @@ def compare_ops(target_func,
metric_bwd_dtype=None, metric_bwd_dtype=None,
in_shardings=_UNSPECIFIED, in_shardings=_UNSPECIFIED,
out_shardings=_UNSPECIFIED, out_shardings=_UNSPECIFIED,
**kwargs): **kwargs,
):
assert len(inputs) >= 1 assert len(inputs) >= 1
if metric_fwd_dtype is None: if metric_fwd_dtype is None:
......
This diff is collapsed.
...@@ -9,16 +9,8 @@ import jax.numpy as jnp ...@@ -9,16 +9,8 @@ import jax.numpy as jnp
import numpy as np import numpy as np
from flax.linen import dot_product_attention from flax.linen import dot_product_attention
from jax import random from jax import random
from jax.sharding import ( from jax.sharding import Mesh, NamedSharding, PartitionSpec
Mesh, from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
NamedSharding,
PartitionSpec
)
from distributed_test_base import (
generate_configs,
generate_collectives_count,
compare_ops
)
from utils import make_causal_mask, make_self_mask from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import ( from transformer_engine.jax.attention import (
...@@ -27,7 +19,7 @@ from transformer_engine.jax.attention import ( ...@@ -27,7 +19,7 @@ from transformer_engine.jax.attention import (
fused_attn_kvpacked, fused_attn_kvpacked,
AttnBiasType, AttnBiasType,
AttnMaskType, AttnMaskType,
QKVLayout QKVLayout,
) )
...@@ -36,8 +28,9 @@ DTYPES = [jnp.float16, jnp.bfloat16] ...@@ -36,8 +28,9 @@ DTYPES = [jnp.float16, jnp.bfloat16]
class TestDistributedSelfAttn: class TestDistributedSelfAttn:
def generate_collectives_count_ref(self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, def generate_collectives_count_ref(
dtype): self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype) jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
_, seqlen, _, heads, _ = shape _, seqlen, _, heads, _ = shape
is_dp_enabled = mesh_resource.dp_resource is not None is_dp_enabled = mesh_resource.dp_resource is not None
...@@ -57,8 +50,11 @@ class TestDistributedSelfAttn: ...@@ -57,8 +50,11 @@ class TestDistributedSelfAttn:
qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype) qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
bias = random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) \ bias = (
if with_bias else None random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)
if with_bias
else None
)
mask = None mask = None
if attn_mask_type == AttnMaskType.PADDING_MASK: if attn_mask_type == AttnMaskType.PADDING_MASK:
...@@ -66,39 +62,66 @@ class TestDistributedSelfAttn: ...@@ -66,39 +62,66 @@ class TestDistributedSelfAttn:
elif attn_mask_type == AttnMaskType.CAUSAL_MASK: elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_self_mask(batch, seqlen) mask = make_self_mask(batch, seqlen)
qkv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, qkv_pspec = PartitionSpec(
None) mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
bias_pspec = PartitionSpec(None, mesh_resource.tp_resource, None, None) \ )
if with_bias else None bias_pspec = (
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \ PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None
if attn_mask_type != AttnMaskType.NO_MASK else None )
mask_pspec = (
PartitionSpec(mesh_resource.dp_resource, None, None, None)
if attn_mask_type != AttnMaskType.NO_MASK
else None
)
return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]]) @pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
@pytest.mark.parametrize(
"attn_bias_type",
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
'attn_bias_type', "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS]) )
@pytest.mark.parametrize('attn_mask_type', @pytest.mark.parametrize("dtype", DTYPES)
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) def test_self_attn(
@pytest.mark.parametrize('dtype', DTYPES) self,
def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, device_count,
attn_bias_type, attn_mask_type, dtype): mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
attn_mask_type,
dtype,
):
dropout_prob = 0.0 dropout_prob = 0.0
is_training = True is_training = True
scaling_factor = 1.0 scaling_factor = 1.0
_, seqlen, _, num_head, hidden = data_shape _, seqlen, _, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type, if not is_fused_attn_kernel_available(
attn_mask_type, dropout_prob, num_head, num_head, dtype,
seqlen, seqlen, hidden): dtype,
QKVLayout.BS3HD,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_head,
seqlen,
seqlen,
hidden,
):
pytest.skip(f"No FusedAttn backwend found") pytest.skip(f"No FusedAttn backwend found")
def target_func(qkv, bias, mask): def target_func(qkv, bias, mask):
return jnp.mean( return jnp.mean(
fused_attn_qkvpacked(qkv, fused_attn_qkvpacked(
qkv,
bias, bias,
mask, mask,
None, None,
...@@ -106,7 +129,9 @@ class TestDistributedSelfAttn: ...@@ -106,7 +129,9 @@ class TestDistributedSelfAttn:
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_prob, dropout_probability=dropout_prob,
is_training=is_training)) is_training=is_training,
)
)
def ref_func(qkv, bias, mask): def ref_func(qkv, bias, mask):
query, key, value = jnp.split(qkv, [1, 2], axis=-3) query, key, value = jnp.split(qkv, [1, 2], axis=-3)
...@@ -114,7 +139,8 @@ class TestDistributedSelfAttn: ...@@ -114,7 +139,8 @@ class TestDistributedSelfAttn:
key = jnp.squeeze(key) key = jnp.squeeze(key)
value = jnp.squeeze(value) value = jnp.squeeze(value)
output = dot_product_attention(query, output = dot_product_attention(
query,
key, key,
value, value,
bias=bias, bias=bias,
...@@ -122,37 +148,43 @@ class TestDistributedSelfAttn: ...@@ -122,37 +148,43 @@ class TestDistributedSelfAttn:
deterministic=is_training, deterministic=is_training,
dropout_rate=dropout_prob, dropout_rate=dropout_prob,
dropout_rng=None, dropout_rng=None,
dtype=jnp.float32) dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype) return jnp.mean(output).astype(dtype)
with_bias = attn_bias_type != AttnBiasType.NO_BIAS with_bias = attn_bias_type != AttnBiasType.NO_BIAS
(qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = \ (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs(
self.generate_inputs(data_shape, mesh_resource, with_bias, data_shape, mesh_resource, with_bias, attn_mask_type, dtype
attn_mask_type, dtype) )
collective_count_ref = self.generate_collectives_count_ref(mesh_shape, mesh_axes, collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, with_bias, mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype
data_shape, dtype) )
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource): with mesh, fp8_autocast(mesh_resource=mesh_resource):
qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec)) qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
bias_ = jax.device_put(bias, NamedSharding(mesh, bias_pspec)) \ bias_ = (
if bias is not None else bias jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \ )
if mask is not None else mask mask_ = (
jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
)
grad_args = (0, 1) if with_bias else (0,) grad_args = (0, 1) if with_bias else (0,)
out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,) out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)
compare_ops(target_func, compare_ops(
ref_func, [qkv_, bias_, mask_], target_func,
ref_func,
[qkv_, bias_, mask_],
collective_count_ref, collective_count_ref,
grad_args=grad_args, grad_args=grad_args,
metric_fwd_dtype=dtype, metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(qkv_pspec, bias_pspec, mask_pspec), in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
out_shardings=(None, out_grad_shardings)) out_shardings=(None, out_grad_shardings),
)
class TestDistributedCrossAttn: class TestDistributedCrossAttn:
...@@ -176,20 +208,26 @@ class TestDistributedCrossAttn: ...@@ -176,20 +208,26 @@ class TestDistributedCrossAttn:
q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None) q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)
kv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, kv_pspec = PartitionSpec(
None) mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \ )
if attn_mask_type != AttnMaskType.NO_MASK else None mask_pspec = (
PartitionSpec(mesh_resource.dp_resource, None, None, None)
if attn_mask_type != AttnMaskType.NO_MASK
else None
)
return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 12, 64], [32, 512, 16, 64]]) @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
@pytest.mark.parametrize('attn_mask_type', @pytest.mark.parametrize(
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
@pytest.mark.parametrize('dtype', DTYPES) )
def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, @pytest.mark.parametrize("dtype", DTYPES)
attn_mask_type, dtype): def test_cross_attn(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
):
attn_bias_type = AttnBiasType.NO_BIAS attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0 dropout_prob = 0.0
is_training = True is_training = True
...@@ -197,14 +235,25 @@ class TestDistributedCrossAttn: ...@@ -197,14 +235,25 @@ class TestDistributedCrossAttn:
_, seqlen, num_head, hidden = data_shape _, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type, if not is_fused_attn_kernel_available(
attn_mask_type, dropout_prob, num_head, num_head, dtype,
seqlen, seqlen, hidden): dtype,
QKVLayout.BSHD_BS2HD,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_head,
seqlen,
seqlen,
hidden,
):
pytest.skip(f"No FusedAttn backwend found") pytest.skip(f"No FusedAttn backwend found")
def target_func(q, kv, mask): def target_func(q, kv, mask):
return jnp.mean( return jnp.mean(
fused_attn_kvpacked(q, fused_attn_kvpacked(
q,
kv, kv,
None, None,
mask, mask,
...@@ -213,7 +262,9 @@ class TestDistributedCrossAttn: ...@@ -213,7 +262,9 @@ class TestDistributedCrossAttn:
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_prob, dropout_probability=dropout_prob,
is_training=is_training)) is_training=is_training,
)
)
def ref_func(query, kv, mask): def ref_func(query, kv, mask):
key, value = jnp.split(kv, [1], axis=-3) key, value = jnp.split(kv, [1], axis=-3)
...@@ -221,7 +272,8 @@ class TestDistributedCrossAttn: ...@@ -221,7 +272,8 @@ class TestDistributedCrossAttn:
key = jnp.squeeze(key) key = jnp.squeeze(key)
value = jnp.squeeze(value) value = jnp.squeeze(value)
output = dot_product_attention(query, output = dot_product_attention(
query,
key, key,
value, value,
bias=None, bias=None,
...@@ -229,26 +281,32 @@ class TestDistributedCrossAttn: ...@@ -229,26 +281,32 @@ class TestDistributedCrossAttn:
deterministic=is_training, deterministic=is_training,
dropout_rate=dropout_prob, dropout_rate=dropout_prob,
dropout_rng=None, dropout_rng=None,
dtype=jnp.float32) dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype) return jnp.mean(output).astype(dtype)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = \ (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
self.generate_inputs(data_shape, mesh_resource, attn_mask_type, dtype) data_shape, mesh_resource, attn_mask_type, dtype
)
collective_count_ref = self.generate_collectives_count_ref() collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource): with mesh, fp8_autocast(mesh_resource=mesh_resource):
q_ = jax.device_put(q, NamedSharding(mesh, q_pspec)) q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec)) kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \ mask_ = (
if mask is not None else mask jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
)
compare_ops(target_func,
ref_func, [q_, kv_, mask_], compare_ops(
target_func,
ref_func,
[q_, kv_, mask_],
collective_count_ref, collective_count_ref,
grad_args=(0, 1), grad_args=(0, 1),
metric_fwd_dtype=dtype, metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(q_pspec, kv_pspec, mask_pspec), in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec))) out_shardings=(None, (q_pspec, kv_pspec)),
)
...@@ -35,31 +35,44 @@ class TestDistributedLayernorm: ...@@ -35,31 +35,44 @@ class TestDistributedLayernorm:
else: else:
raise NotImplementedError raise NotImplementedError
g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None) g_pspec = b_pspec = (
PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
)
return (x, gamma, beta), (x_pspec, g_pspec, b_pspec) return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype): def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype) jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
is_dp_enabled = mesh_resource.dp_resource is not None is_dp_enabled = mesh_resource.dp_resource is not None
assert ln_type in ['layernorm', 'rmsnorm'] assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32 all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta # for loss, dgamma and dbeta
weight_count = 2 if ln_type == 'layernorm' else 1 weight_count = 2 if ln_type == "layernorm" else 1
allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize allreduce_total_bytes = (
return generate_collectives_count(allreduce=allreduce_total_bytes * int(is_dp_enabled), all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
allgather=0, )
other=0) return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) )
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize('zero_centered_gamma', [False, True]) @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('shard_weights', [False, True]) @pytest.mark.parametrize("dtype", DTYPES)
def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, @pytest.mark.parametrize("zero_centered_gamma", [False, True])
zero_centered_gamma, shard_weights): @pytest.mark.parametrize("shard_weights", [False, True])
def test_layernorm(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
dtype,
zero_centered_gamma,
shard_weights,
):
epsilon = 1e-6 epsilon = 1e-6
ln_type = 'layernorm' ln_type = "layernorm"
def target_func(x, gamma, beta): def target_func(x, gamma, beta):
return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)) return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
...@@ -75,10 +88,12 @@ class TestDistributedLayernorm: ...@@ -75,10 +88,12 @@ class TestDistributedLayernorm:
output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype) output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
return jnp.mean(output) return jnp.mean(output)
(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \ (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights) data_shape, mesh_resource, dtype, shard_weights
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, )
data_shape, dtype) collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource): with mesh, fp8_autocast(mesh_resource=mesh_resource):
...@@ -88,20 +103,25 @@ class TestDistributedLayernorm: ...@@ -88,20 +103,25 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:
try: try:
compare_ops(target_func, compare_ops(
ref_func, [x_, gamma_, beta_], target_func,
ref_func,
[x_, gamma_, beta_],
collective_count_ref, collective_count_ref,
grad_args=(0, 1, 2), grad_args=(0, 1, 2),
metric_fwd_dtype=dtype, metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec), in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec))) out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
)
except AssertionError as err: except AssertionError as err:
# Layernorm should still produce the correct numerical result with # Layernorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same # gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma and/or beta. We can catch # when XLA is forced to unshard gamma and/or beta. We can catch
# and ignore that specific error here. # and ignore that specific error here.
if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err): if (
g_pspec[-1] is None and b_pspec[-1] is None
) or "Expected collective count" not in str(err):
raise err raise err
finally: finally:
for w in warns: for w in warns:
...@@ -110,13 +130,15 @@ class TestDistributedLayernorm: ...@@ -110,13 +130,15 @@ class TestDistributedLayernorm:
"unsupported sharding of gamma and/or beta" "unsupported sharding of gamma and/or beta"
) )
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]]) @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize('shard_weights', [False, True]) @pytest.mark.parametrize("shard_weights", [False, True])
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights): def test_rmsnorm(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights
):
epsilon = 1e-6 epsilon = 1e-6
ln_type = 'rmsnorm' ln_type = "rmsnorm"
def target_func(x, gamma): def target_func(x, gamma):
return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon)) return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
...@@ -128,10 +150,12 @@ class TestDistributedLayernorm: ...@@ -128,10 +150,12 @@ class TestDistributedLayernorm:
output = y * gamma output = y * gamma
return jnp.mean(output) return jnp.mean(output)
(x, gamma, _), (x_pspec, g_pspec, _) = \ (x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights) data_shape, mesh_resource, dtype, shard_weights
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, )
data_shape, dtype) collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource): with mesh, fp8_autocast(mesh_resource=mesh_resource):
...@@ -140,14 +164,17 @@ class TestDistributedLayernorm: ...@@ -140,14 +164,17 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:
try: try:
compare_ops(target_func, compare_ops(
ref_func, [x_, gamma_], target_func,
ref_func,
[x_, gamma_],
collective_count_ref, collective_count_ref,
grad_args=(0, 1), grad_args=(0, 1),
metric_fwd_dtype=dtype, metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec), in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec))) out_shardings=(None, (x_pspec, g_pspec)),
)
except AssertionError as err: except AssertionError as err:
# RmsNorm should still produce the correct numerical result with # RmsNorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same # gamma/beta sharded. However, the collective count may not be the same
......
...@@ -15,18 +15,19 @@ from transformer_engine.jax import fp8_autocast ...@@ -15,18 +15,19 @@ from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
from transformer_engine.jax.sharding import ( from transformer_engine.jax.sharding import (
HIDDEN_AXES, HIDDEN_TP_AXES, HIDDEN_AXES,
HIDDEN_TP_AXES,
BATCH_AXES, BATCH_AXES,
SEQLEN_TP_AXES, SEQLEN_AXES, SEQLEN_TP_AXES,
W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES SEQLEN_AXES,
W_NO_SHARD_AXES,
W_FSDP_AXES,
W_TP_AXES,
W_JOINED_AXES,
) )
from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.sharding import MeshResource
from utils import ( from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough
assert_allclose,
assert_tree_like_allclose,
is_devices_enough
)
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
DTYPES = [jnp.bfloat16, jnp.float16] DTYPES = [jnp.bfloat16, jnp.float16]
...@@ -43,13 +44,13 @@ def generate_fsdp_and_tp_configs(): ...@@ -43,13 +44,13 @@ def generate_fsdp_and_tp_configs():
configs = [] configs = []
if is_devices_enough(2): if is_devices_enough(2):
configs.append( configs.append(
[2, (1, 2), ('fsdp', 'tp'), [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
MeshResource(fsdp_resource='fsdp', tp_resource='tp')]) )
if is_devices_enough(4): if is_devices_enough(4):
configs.append( configs.append(
[4, (2, 2), ('fsdp', 'tp'), [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
MeshResource(fsdp_resource='fsdp', tp_resource='tp')]) )
return configs return configs
...@@ -64,10 +65,12 @@ class TestDistributedLayernormMLP: ...@@ -64,10 +65,12 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype) gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal(subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), k1 = jax.random.normal(
dtype) / jnp.sqrt(hidden_in) subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
k2 = jax.random.normal(subkeys[2], ) / jnp.sqrt(hidden_in)
(INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(INTERMEDIATE) k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
INTERMEDIATE
)
if use_bias: if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype) b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype) b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
...@@ -90,15 +93,27 @@ class TestDistributedLayernormMLP: ...@@ -90,15 +93,27 @@ class TestDistributedLayernormMLP:
scale_list_1: List[jnp.ndarray], scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray], scale_list_2: List[jnp.ndarray],
layernorm_type: str = "rmsnorm", layernorm_type: str = "rmsnorm",
activation_type: Sequence[Union[str, Callable]] = ('gelu',), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
use_bias: bool = True, use_bias: bool = True,
multi_gpus: bool = False, multi_gpus: bool = False,
) -> jnp.ndarray: ) -> jnp.ndarray:
fp8_meta_pkg1 = FP8MetaPackage(amax_list_1[0], scale_list_1[0], amax_list_1[1], fp8_meta_pkg1 = FP8MetaPackage(
scale_list_1[1], amax_list_1[2], scale_list_1[2]) amax_list_1[0],
fp8_meta_pkg2 = FP8MetaPackage(amax_list_2[0], scale_list_2[0], amax_list_2[1], scale_list_1[0],
scale_list_2[1], amax_list_2[2], scale_list_2[2]) amax_list_1[1],
scale_list_1[1],
amax_list_1[2],
scale_list_1[2],
)
fp8_meta_pkg2 = FP8MetaPackage(
amax_list_2[0],
scale_list_2[0],
amax_list_2[1],
scale_list_2[1],
amax_list_2[2],
scale_list_2[2],
)
if multi_gpus: if multi_gpus:
layernorm_input_axes = LAYERNORM_INPUT_AXES layernorm_input_axes = LAYERNORM_INPUT_AXES
...@@ -111,60 +126,68 @@ class TestDistributedLayernormMLP: ...@@ -111,60 +126,68 @@ class TestDistributedLayernormMLP:
# out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2 # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
return jnp.mean( return jnp.mean(
fused_layernorm_fp8_mlp(x, fused_layernorm_fp8_mlp(
x,
ln_scale, ln_scale,
None, [kernel_1, kernel_2], [bias_1, bias_2], None,
[kernel_1, kernel_2],
[bias_1, bias_2],
[fp8_meta_pkg1, fp8_meta_pkg2], [fp8_meta_pkg1, fp8_meta_pkg2],
layernorm_type, layernorm_type,
layernorm_input_axes=layernorm_input_axes, layernorm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes, dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes, dot_2_input_axes=dot_2_input_axes,
activation_type=activation_type, activation_type=activation_type,
use_bias=use_bias)) use_bias=use_bias,
)
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs()) @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('input_shape', INPUT_SHAPE) @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear')]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize('use_bias', [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_layernorm_fp8_mlp_primitive(self, mesh_config, activation_type, use_bias, input_shape, def test_layernorm_fp8_mlp_primitive(
dtype): self, mesh_config, activation_type, use_bias, input_shape, dtype
):
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = 'rmsnorm' layernorm_type = "rmsnorm"
fp8_amax_list_1 = [ fp8_amax_list_1 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32) jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
] ]
fp8_amax_list_2 = [ fp8_amax_list_2 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32) jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
] ]
fp8_scale_list_1 = [ fp8_scale_list_1 = [
jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32) jnp.ones((1,), jnp.float32),
] ]
fp8_scale_list_2 = [ fp8_scale_list_2 = [
jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32) jnp.ones((1,), jnp.float32),
] ]
inputs = [x, gamma, k1, k2, b1, b2] = \ inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
self.generate_inputs(input_shape, activation_type, use_bias, dtype) input_shape, activation_type, use_bias, dtype
)
inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2] inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2]
static_inputs = [layernorm_type, activation_type, use_bias] static_inputs = [layernorm_type, activation_type, use_bias]
value_and_grad_func = jax.value_and_grad(self.layernorm_fp8_mlp_prim_func, value_and_grad_func = jax.value_and_grad(
argnums=range(len(inputs))) self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
)
# Single GPU # Single GPU
single_jitter = jax.jit(value_and_grad_func, single_jitter = jax.jit(
static_argnums=range(len(inputs), value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs))
len(static_inputs) + len(inputs))) )
with fp8_autocast(enabled=True): with fp8_autocast(enabled=True):
single_fwd, single_grads = single_jitter(*inputs, *static_inputs) single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
...@@ -172,12 +195,12 @@ class TestDistributedLayernormMLP: ...@@ -172,12 +195,12 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource): with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec('fsdp', None, 'tp')) k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec('tp', 'fsdp')) k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding) k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding) k2_ = jax.device_put(k2, k2_sharding)
if use_bias: if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec(None, 'tp')) b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding) b1_ = jax.device_put(b1, b1_sharding)
else: else:
b1_sharding = b1_ = None b1_sharding = b1_ = None
...@@ -186,17 +209,29 @@ class TestDistributedLayernormMLP: ...@@ -186,17 +209,29 @@ class TestDistributedLayernormMLP:
# Position ref for sharding pspec lists # Position ref for sharding pspec lists
# x, gamma, k1, k2, b1, # x, gamma, k1, k2, b1,
# b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv # b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
in_shardings = (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, in_shardings = (
None, None) None,
out_shardings = (None, (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None,
None, None, None)) k1_sharding,
k2_sharding,
b1_sharding,
None,
None,
None,
None,
None,
)
out_shardings = (
None,
(None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, None, None),
)
multi_jitter = jax.jit(value_and_grad_func, multi_jitter = jax.jit(
value_and_grad_func,
in_shardings=in_shardings, in_shardings=in_shardings,
out_shardings=out_shardings, out_shardings=out_shardings,
static_argnums=range(len(multi_inputs), static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
len(static_inputs) + len(multi_inputs) + ) # +1 for multi_gpus
1)) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
...@@ -206,26 +241,28 @@ class TestDistributedLayernormMLP: ...@@ -206,26 +241,28 @@ class TestDistributedLayernormMLP:
if isinstance(multi_grads[i], list): if isinstance(multi_grads[i], list):
assert isinstance(single_grads[i], list) assert isinstance(single_grads[i], list)
for m_grad, s_grad in zip(multi_grads[i], single_grads[i]): for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
assert_allclose(m_grad, assert_allclose(
s_grad, m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
dtype=dtype, )
err_msg=f'multi_grads[{i}] is not close')
else: else:
assert_allclose(multi_grads[i], assert_allclose(
multi_grads[i],
single_grads[i], single_grads[i],
dtype=dtype, dtype=dtype,
err_msg=f'multi_grads[{i}] is not close') err_msg=f"multi_grads[{i}] is not close",
)
def _test_layernorm_mlp(self, mesh_config, activation_type, use_bias, input_shape, dtype, def _test_layernorm_mlp(
use_fp8): self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8
):
batch, seqlen, hidden_in = input_shape batch, seqlen, hidden_in = input_shape
layernorm_type = 'rmsnorm' layernorm_type = "rmsnorm"
rng = jax.random.PRNGKey(0) rng = jax.random.PRNGKey(0)
subkeys = jax.random.split(rng, 2) subkeys = jax.random.split(rng, 2)
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
init_rngs = {'params': subkeys[1]} init_rngs = {"params": subkeys[1]}
# Single GPUs # Single GPUs
with fp8_autocast(enabled=use_fp8): with fp8_autocast(enabled=use_fp8):
...@@ -238,16 +275,17 @@ class TestDistributedLayernormMLP: ...@@ -238,16 +275,17 @@ class TestDistributedLayernormMLP:
use_bias=use_bias, use_bias=use_bias,
) )
params_single = ln_mlp_single.init(init_rngs, x) params_single = ln_mlp_single.init(init_rngs, x)
mlp_out_single, ln_out_single = ln_mlp_single.apply(params_single, mlp_out_single, ln_out_single = ln_mlp_single.apply(
x, params_single, x, deterministic=True
deterministic=True) )
# Multi GPUs # Multi GPUs
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource): with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
ln_mlp_sharded = LayerNormMLP(layernorm_type=layernorm_type, ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False, transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
activations=activation_type, activations=activation_type,
...@@ -262,41 +300,37 @@ class TestDistributedLayernormMLP: ...@@ -262,41 +300,37 @@ class TestDistributedLayernormMLP:
layernorm_input_axes=LAYERNORM_INPUT_AXES, layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES, dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES, dot_2_input_axes=DOT_2_INPUT_AXES,
name='mlp') name="mlp",
)
params_sharded = ln_mlp_sharded.init(init_rngs, x) params_sharded = ln_mlp_sharded.init(init_rngs, x)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(params_sharded, mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
x, params_sharded, x, deterministic=True
deterministic=True) )
# Make sure params values are the same # Make sure params values are the same
assert_tree_like_allclose(params_sharded['params'], params_single['params']) assert_tree_like_allclose(params_sharded["params"], params_single["params"])
assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype) assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype) assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)
@pytest.mark.parametrize('input_shape', INPUT_SHAPE) @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs()) @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('activation_type', [("gelu",), ('silu', 'linear'), ('gelu', 'gelu')]) @pytest.mark.parametrize("activation_type", [("gelu",), ("silu", "linear"), ("gelu", "gelu")])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize('use_bias', [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype): def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
self._test_layernorm_mlp(mesh_config, self._test_layernorm_mlp(
activation_type, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
use_bias, )
input_shape,
dtype,
use_fp8=False)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs()) @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear'), ('gelu', 'gelu')]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear"), ("gelu", "gelu")])
@pytest.mark.parametrize('use_bias', [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize('input_shape', INPUT_SHAPE) @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
def test_layernorm_fp8_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, def test_layernorm_fp8_mlp_layer(
dtype): self, mesh_config, activation_type, use_bias, input_shape, dtype
self._test_layernorm_mlp(mesh_config, ):
activation_type, self._test_layernorm_mlp(
use_bias, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=True
input_shape, )
dtype,
use_fp8=True)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -3,4 +3,5 @@ ...@@ -3,4 +3,5 @@
# See LICENSE for license information. # See LICENSE for license information.
import transformer_engine.jax import transformer_engine.jax
print("OK") print("OK")
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment