Unverified Commit 9416519d authored by Kirthi Shankar Sivamani, committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
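
The hunks below are the mechanical output of a code formatter. A minimal sketch of reproducing one of the reflows, assuming the formatter is black with a 100-character line length (the commit message does not name the tool, so treat this as an assumption):

# Hypothetical reproduction of one reflow from the diff below; assumes black is installed
# and that the project formats with a 100-character line length.
import black

src = (
    "parser.add_argument(\n"
    '    "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"\n'
    ")\n"
)
print(black.format_str(src, mode=black.Mode(line_length=100)))
# parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")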
@@ -79,6 +79,7 @@ def calibrate(model, device, test_loader, fp8):
with te.fp8_autocast(enabled=fp8, calibrating=True):
output = model(data)
def test(model, device, test_loader, use_fp8):
"""Testing function."""
model.eval()
@@ -89,12 +90,8 @@ def test(model, device, test_loader, use_fp8):
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8):
output = model(data)
test_loss += F.nll_loss(
output, target, reduction="sum"
).item() # sum up batch loss
pred = output.argmax(
dim=1, keepdim=True
) # get the index of the max log-probability
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
@@ -150,9 +147,7 @@ def main():
default=False,
help="quickly check a single pass",
)
parser.add_argument(
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
@@ -167,7 +162,10 @@ def main():
help="For Saving the current Model",
)
parser.add_argument(
"--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration"
"--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
@@ -215,7 +213,7 @@ def main():
if args.save_model or args.use_fp8_infer:
torch.save(model.state_dict(), "mnist_cnn.pt")
print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer))
print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer))
weights = torch.load("mnist_cnn.pt")
model.load_state_dict(weights)
test(model, device, test_loader, args.use_fp8_infer)
......
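
The MNIST hunks above toggle te.fp8_autocast between a calibration pass and an FP8 inference pass. A minimal standalone sketch of that pattern, assuming transformer_engine.pytorch is installed and an FP8-capable GPU is available (module and shapes here are illustrative, not taken from the example):

# Calibrate FP8 scales in high precision, then run FP8 inference with those scales.
import torch
import transformer_engine.pytorch as te

model = te.Linear(784, 256).cuda()
data = torch.randn(32, 784, device="cuda")

# Calibration: compute in higher precision while recording amax history for FP8 scaling.
with torch.no_grad(), te.fp8_autocast(enabled=False, calibrating=True):
    _ = model(data)

# Inference: enable FP8 and reuse the calibrated scaling factors.
with torch.no_grad(), te.fp8_autocast(enabled=True):
    out = model(data)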
@@ -18,21 +18,26 @@ path = sys.argv[1]
config_path = os.path.dirname(os.path.realpath(__file__)) + "/config.json"
class bcolors:
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
OKGREEN = "\033[92m"
WARNING = "\033[93m"
FAIL = "\033[91m"
ENDC = "\033[0m"
def print_ok(msg):
print(f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}")
def print_fail(msg):
print(f"{bcolors.FAIL}{msg}{bcolors.ENDC}")
def print_warn(msg):
print(f"{bcolors.WARNING}{msg}{bcolors.ENDC}")
with open(config_path, "r") as f:
c = json.load(f)
current_year = datetime.date.today().year
@@ -41,7 +46,7 @@ with open(config_path, "r") as f:
else:
year_string = str(c["initial_year"]) + "-" + str(current_year)
copyright_string = c["copyright"].replace("<YEAR>", year_string)
license = c["license"].split('\n')
license = c["license"].split("\n")
excludes = c["exclude"]
root_path = os.path.abspath(path)
copyright_only = c["copyright_only"]
@@ -49,21 +54,25 @@ with open(config_path, "r") as f:
has_gitignore = os.path.exists(root_path + "/.gitignore")
def strip_star_slash(s):
ret = s
if ret.startswith('*'):
if ret.startswith("*"):
ret = ret[1:]
if ret.endswith('/'):
if ret.endswith("/"):
ret = ret[:-1]
return ret
if has_gitignore:
with open(root_path + "/.gitignore", "r") as f:
for line in f.readlines():
excludes.append(strip_star_slash(line.strip()))
def get_file_type(path):
ext = {"c": ["c", "cpp", "cu", "h", "cuh"],
ext = {
"c": ["c", "cpp", "cu", "h", "cuh"],
"py": ["py"],
"rst": ["rst"],
"txt": ["txt"],
@@ -77,8 +86,10 @@ def get_file_type(path):
return filetype
return "unknown"
success = True
def check_file(path):
global success
N = 10
@@ -127,9 +138,10 @@ def check_file(path):
if copyright_found and license_found:
print_ok("OK")
for root, dirs, files in os.walk(root_path):
print(f"Entering {root}")
hidden = [d for d in dirs if d.startswith('.')] + [f for f in files if f.startswith('.')]
hidden = [d for d in dirs if d.startswith(".")] + [f for f in files if f.startswith(".")]
all_excludes = excludes + hidden
to_remove = []
for d in dirs:
......
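
A standalone illustration of the strip_star_slash helper reformatted above, showing how the checker folds .gitignore patterns into its exclude list (the sample patterns are made up):

# strip_star_slash is copied verbatim from the script above; inputs are illustrative.
def strip_star_slash(s):
    ret = s
    if ret.startswith("*"):
        ret = ret[1:]
    if ret.endswith("/"):
        ret = ret[:-1]
    return ret

patterns = ["*.so", "build/", "__pycache__/"]
print([strip_star_slash(p) for p in patterns])  # ['.so', 'build', '__pycache__']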
@@ -27,12 +27,13 @@ current_file_path = Path(__file__).parent.resolve()
from setuptools.command.build_ext import build_ext as BuildExtension
if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import('pybind11')
install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension
@@ -86,34 +87,45 @@ if __name__ == "__main__":
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension
ext_modules.append(
setup_pytorch_extension(
"transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine"))
current_file_path / "transformer_engine",
)
)
if "jax" in frameworks:
from build_tools.jax import setup_jax_extension
ext_modules.append(
setup_jax_extension(
"transformer_engine/jax/csrc",
current_file_path / "transformer_engine" / "jax" / "csrc",
current_file_path / "transformer_engine"))
current_file_path / "transformer_engine",
)
)
if "paddle" in frameworks:
from build_tools.paddle import setup_paddle_extension
ext_modules.append(
setup_paddle_extension(
"transformer_engine/paddle/csrc",
current_file_path / "transformer_engine" / "paddle" / "csrc",
current_file_path / "transformer_engine"))
current_file_path / "transformer_engine",
)
)
# Configure package
setuptools.setup(
name="transformer_engine",
version=__version__,
packages=setuptools.find_packages(
include=["transformer_engine",
include=[
"transformer_engine",
"transformer_engine.*",
"transformer_engine/build_tools"],
"transformer_engine/build_tools",
],
),
extras_require={
"test": test_requires,
@@ -125,5 +137,5 @@ if __name__ == "__main__":
install_requires=install_requires,
license_files=("LICENSE",),
include_package_data=True,
package_data={"": ["VERSION.txt"]}
package_data={"": ["VERSION.txt"]},
)
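
The setup.py hunk above calls install_and_import("pybind11") before importing pybind11's build_ext. A hedged sketch of what such a helper typically does; the real helper is defined elsewhere in setup.py and may differ:

# Hypothetical install_and_import: pip-install a package if missing, then import it.
import importlib
import subprocess
import sys

def install_and_import(package):
    try:
        return importlib.import_module(package)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        return importlib.import_module(package)

pybind11 = install_and_import("pybind11")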
@@ -6,7 +6,7 @@ import jax
import pytest
@pytest.fixture(autouse=True, scope='function')
@pytest.fixture(autouse=True, scope="function")
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
......
@@ -16,15 +16,15 @@ from utils import assert_allclose, is_devices_enough
def generate_configs():
configs = []
if is_devices_enough(2):
configs.append([2, (2,), ('dp'), MeshResource(dp_resource='dp')])
configs.append([2, (2,), ('tp'), MeshResource(tp_resource='tp')])
configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])
if is_devices_enough(4):
TP_size = 2
DP_size = 2
configs.append(
[4, (DP_size, TP_size), ('dp', 'tp'),
MeshResource(dp_resource='dp', tp_resource='tp')])
[4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
)
return configs
@@ -46,7 +46,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
bytes_count = 0
def get_bytes_per_txt(t):
'''
"""
The pattern of t would be like:
'f32[]',
'(f32[1024]{0}',
@@ -54,22 +54,22 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
'f8E4M3FN[1024]{0}',
'i32[1024]{0}',
'bf16[1024,1024]{0}'
'''
match = re.search(r'(i|f)(\d+).*\[([0-9,]*)\]', t)
"""
match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
_, bits_of_type, shape = match.groups()
bytes_of_type = int(bits_of_type) // 8
if shape == '':
if shape == "":
num_of_elements = 1
else:
num_of_elements = reduce(operator.mul, map(int, shape.split(',')))
num_of_elements = reduce(operator.mul, map(int, shape.split(",")))
return bytes_of_type * num_of_elements
# ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...]
if '(' in hlo_text[2]:
if "(" in hlo_text[2]:
for txt in hlo_text[2:]:
bytes_count += get_bytes_per_txt(txt)
if ')' in txt:
if ")" in txt:
break
else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
bytes_count = get_bytes_per_txt(hlo_text[2])
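
A worked example of the get_bytes_per_txt helper reformatted in the hunk above, run on the sample HLO fragments from its docstring (the helper body is copied verbatim):

import operator
import re
from functools import reduce

def get_bytes_per_txt(t):
    match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
    _, bits_of_type, shape = match.groups()
    bytes_of_type = int(bits_of_type) // 8
    if shape == "":
        num_of_elements = 1
    else:
        num_of_elements = reduce(operator.mul, map(int, shape.split(",")))
    return bytes_of_type * num_of_elements

print(get_bytes_per_txt("f32[]"))               # 4  (one FP32 scalar)
print(get_bytes_per_txt("f8E4M3FN[1024]{0}"))   # 1024  (1 byte x 1024)
print(get_bytes_per_txt("bf16[1024,1024]{0}"))  # 2097152  (2 bytes x 1024 x 1024)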
@@ -91,11 +91,13 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
return result
target_result = count_collectives(target_splitted_hlo)
assert target_result == coll_count_ref, \
f"Expected collective count is {coll_count_ref}, but got {target_result}."
assert (
target_result == coll_count_ref
), f"Expected collective count is {coll_count_ref}, but got {target_result}."
def compare_ops(target_func,
def compare_ops(
target_func,
ref_func,
inputs,
coll_count_ref,
@@ -105,7 +107,8 @@ def compare_ops(target_func,
metric_bwd_dtype=None,
in_shardings=_UNSPECIFIED,
out_shardings=_UNSPECIFIED,
**kwargs):
**kwargs,
):
assert len(inputs) >= 1
if metric_fwd_dtype is None:
......
@@ -9,16 +9,8 @@ import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import (
Mesh,
NamedSharding,
PartitionSpec
)
from distributed_test_base import (
generate_configs,
generate_collectives_count,
compare_ops
)
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
@@ -27,7 +19,7 @@ from transformer_engine.jax.attention import (
fused_attn_kvpacked,
AttnBiasType,
AttnMaskType,
QKVLayout
QKVLayout,
)
@@ -36,8 +28,9 @@ DTYPES = [jnp.float16, jnp.bfloat16]
class TestDistributedSelfAttn:
def generate_collectives_count_ref(self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape,
dtype):
def generate_collectives_count_ref(
self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
_, seqlen, _, heads, _ = shape
is_dp_enabled = mesh_resource.dp_resource is not None
@@ -57,8 +50,11 @@ class TestDistributedSelfAttn:
qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
bias = random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) \
if with_bias else None
bias = (
random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)
if with_bias
else None
)
mask = None
if attn_mask_type == AttnMaskType.PADDING_MASK:
@@ -66,39 +62,66 @@ class TestDistributedSelfAttn:
elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_self_mask(batch, seqlen)
qkv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
None)
bias_pspec = PartitionSpec(None, mesh_resource.tp_resource, None, None) \
if with_bias else None
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
if attn_mask_type != AttnMaskType.NO_MASK else None
qkv_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
)
bias_pspec = (
PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None
)
mask_pspec = (
PartitionSpec(mesh_resource.dp_resource, None, None, None)
if attn_mask_type != AttnMaskType.NO_MASK
else None
)
return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
@pytest.mark.parametrize(
"attn_bias_type",
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
)
@pytest.mark.parametrize(
'attn_bias_type',
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dtype', DTYPES)
def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
attn_bias_type, attn_mask_type, dtype):
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_self_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
attn_mask_type,
dtype,
):
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
_, seqlen, _, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
attn_mask_type, dropout_prob, num_head, num_head,
seqlen, seqlen, hidden):
if not is_fused_attn_kernel_available(
dtype,
dtype,
QKVLayout.BS3HD,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_head,
seqlen,
seqlen,
hidden,
):
pytest.skip(f"No FusedAttn backwend found")
def target_func(qkv, bias, mask):
return jnp.mean(
fused_attn_qkvpacked(qkv,
fused_attn_qkvpacked(
qkv,
bias,
mask,
None,
@@ -106,7 +129,9 @@ class TestDistributedSelfAttn:
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training))
is_training=is_training,
)
)
def ref_func(qkv, bias, mask):
query, key, value = jnp.split(qkv, [1, 2], axis=-3)
@@ -114,7 +139,8 @@ class TestDistributedSelfAttn:
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
output = dot_product_attention(
query,
key,
value,
bias=bias,
@@ -122,37 +148,43 @@ class TestDistributedSelfAttn:
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32)
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)
with_bias = attn_bias_type != AttnBiasType.NO_BIAS
(qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, with_bias,
attn_mask_type, dtype)
collective_count_ref = self.generate_collectives_count_ref(mesh_shape, mesh_axes,
mesh_resource, with_bias,
data_shape, dtype)
(qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, with_bias, attn_mask_type, dtype
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
bias_ = jax.device_put(bias, NamedSharding(mesh, bias_pspec)) \
if bias is not None else bias
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
if mask is not None else mask
bias_ = (
jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias
)
mask_ = (
jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
)
grad_args = (0, 1) if with_bias else (0,)
out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)
compare_ops(target_func,
ref_func, [qkv_, bias_, mask_],
compare_ops(
target_func,
ref_func,
[qkv_, bias_, mask_],
collective_count_ref,
grad_args=grad_args,
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
out_shardings=(None, out_grad_shardings))
out_shardings=(None, out_grad_shardings),
)
class TestDistributedCrossAttn:
@@ -176,20 +208,26 @@ class TestDistributedCrossAttn:
q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)
kv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
None)
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
if attn_mask_type != AttnMaskType.NO_MASK else None
kv_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
)
mask_pspec = (
PartitionSpec(mesh_resource.dp_resource, None, None, None)
if attn_mask_type != AttnMaskType.NO_MASK
else None
)
return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 12, 64], [32, 512, 16, 64]])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dtype', DTYPES)
def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
attn_mask_type, dtype):
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
@pytest.mark.parametrize(
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_cross_attn(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
is_training = True
@@ -197,14 +235,25 @@ class TestDistributedCrossAttn:
_, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type,
attn_mask_type, dropout_prob, num_head, num_head,
seqlen, seqlen, hidden):
if not is_fused_attn_kernel_available(
dtype,
dtype,
QKVLayout.BSHD_BS2HD,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_head,
seqlen,
seqlen,
hidden,
):
pytest.skip(f"No FusedAttn backwend found")
def target_func(q, kv, mask):
return jnp.mean(
fused_attn_kvpacked(q,
fused_attn_kvpacked(
q,
kv,
None,
mask,
@@ -213,7 +262,9 @@ class TestDistributedCrossAttn:
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training))
is_training=is_training,
)
)
def ref_func(query, kv, mask):
key, value = jnp.split(kv, [1], axis=-3)
@@ -221,7 +272,8 @@ class TestDistributedCrossAttn:
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
output = dot_product_attention(
query,
key,
value,
bias=None,
@@ -229,26 +281,32 @@ class TestDistributedCrossAttn:
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32)
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, attn_mask_type, dtype)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, attn_mask_type, dtype
)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
if mask is not None else mask
compare_ops(target_func,
ref_func, [q_, kv_, mask_],
mask_ = (
jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
)
compare_ops(
target_func,
ref_func,
[q_, kv_, mask_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec)))
out_shardings=(None, (q_pspec, kv_pspec)),
)
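
The attention tests above place inputs with jax.device_put and a NamedSharding built from a PartitionSpec over a device Mesh. A minimal sketch of that placement pattern, assuming at least two JAX devices are visible (axis name and shapes are illustrative):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.asarray(jax.devices()[:2]).reshape(2)
mesh = Mesh(devices, ("dp",))

x = jnp.ones((8, 128, 3, 12, 64), jnp.bfloat16)
with mesh:
    # Shard the batch axis across "dp"; replicate the remaining axes.
    spec = PartitionSpec("dp", None, None, None, None)
    x_sharded = jax.device_put(x, NamedSharding(mesh, spec))
    print(x_sharded.sharding)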
@@ -35,31 +35,44 @@ class TestDistributedLayernorm:
else:
raise NotImplementedError
g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
g_pspec = b_pspec = (
PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
)
return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
is_dp_enabled = mesh_resource.dp_resource is not None
assert ln_type in ['layernorm', 'rmsnorm']
assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta
weight_count = 2 if ln_type == 'layernorm' else 1
allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
return generate_collectives_count(allreduce=allreduce_total_bytes * int(is_dp_enabled),
allgather=0,
other=0)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('zero_centered_gamma', [False, True])
@pytest.mark.parametrize('shard_weights', [False, True])
def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype,
zero_centered_gamma, shard_weights):
weight_count = 2 if ln_type == "layernorm" else 1
allreduce_total_bytes = (
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("shard_weights", [False, True])
def test_layernorm(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
dtype,
zero_centered_gamma,
shard_weights,
):
epsilon = 1e-6
ln_type = 'layernorm'
ln_type = "layernorm"
def target_func(x, gamma, beta):
return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
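
A worked check of the generate_collectives_count_ref arithmetic in the hunk above, for the layernorm case with data_shape [32, 128, 1024] and bfloat16 parameters:

all_reduce_loss_bytes = 4    # one FP32 loss value
weight_count = 2             # dgamma and dbeta for layernorm (1 for rmsnorm)
hidden = 1024                # shape[-1]
itemsize = 2                 # bfloat16
print(all_reduce_loss_bytes + weight_count * hidden * itemsize)  # 4100 bytes all-reduced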
@@ -75,10 +88,12 @@ class TestDistributedLayernorm:
output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
return jnp.mean(output)
(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype)
(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
data_shape, mesh_resource, dtype, shard_weights
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
@@ -88,20 +103,25 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, gamma_, beta_],
compare_ops(
target_func,
ref_func,
[x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)))
out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
)
except AssertionError as err:
# Layernorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma and/or beta. We can catch
# and ignore that specific error here.
if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err):
if (
g_pspec[-1] is None and b_pspec[-1] is None
) or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
@@ -110,13 +130,15 @@ class TestDistributedLayernorm:
"unsupported sharding of gamma and/or beta"
)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('shard_weights', [False, True])
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights):
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shard_weights", [False, True])
def test_rmsnorm(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights
):
epsilon = 1e-6
ln_type = 'rmsnorm'
ln_type = "rmsnorm"
def target_func(x, gamma):
return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
@@ -128,10 +150,12 @@ class TestDistributedLayernorm:
output = y * gamma
return jnp.mean(output)
(x, gamma, _), (x_pspec, g_pspec, _) = \
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype)
(x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
data_shape, mesh_resource, dtype, shard_weights
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
......@@ -140,14 +164,17 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, gamma_],
compare_ops(
target_func,
ref_func,
[x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)))
out_shardings=(None, (x_pspec, g_pspec)),
)
except AssertionError as err:
# RmsNorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
......
@@ -15,18 +15,19 @@ from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
from transformer_engine.jax.sharding import (
HIDDEN_AXES, HIDDEN_TP_AXES,
HIDDEN_AXES,
HIDDEN_TP_AXES,
BATCH_AXES,
SEQLEN_TP_AXES, SEQLEN_AXES,
W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
SEQLEN_TP_AXES,
SEQLEN_AXES,
W_NO_SHARD_AXES,
W_FSDP_AXES,
W_TP_AXES,
W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
from utils import (
assert_allclose,
assert_tree_like_allclose,
is_devices_enough
)
from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough
is_fp8_supported, reason = is_fp8_available()
DTYPES = [jnp.bfloat16, jnp.float16]
@@ -43,13 +44,13 @@ def generate_fsdp_and_tp_configs():
configs = []
if is_devices_enough(2):
configs.append(
[2, (1, 2), ('fsdp', 'tp'),
MeshResource(fsdp_resource='fsdp', tp_resource='tp')])
[2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
)
if is_devices_enough(4):
configs.append(
[4, (2, 2), ('fsdp', 'tp'),
MeshResource(fsdp_resource='fsdp', tp_resource='tp')])
[4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
)
return configs
@@ -64,10 +65,12 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal(subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE),
dtype) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2],
(INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(INTERMEDIATE)
k1 = jax.random.normal(
subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
INTERMEDIATE
)
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
@@ -90,15 +93,27 @@ class TestDistributedLayernormMLP:
scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray],
layernorm_type: str = "rmsnorm",
activation_type: Sequence[Union[str, Callable]] = ('gelu',),
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
use_bias: bool = True,
multi_gpus: bool = False,
) -> jnp.ndarray:
fp8_meta_pkg1 = FP8MetaPackage(amax_list_1[0], scale_list_1[0], amax_list_1[1],
scale_list_1[1], amax_list_1[2], scale_list_1[2])
fp8_meta_pkg2 = FP8MetaPackage(amax_list_2[0], scale_list_2[0], amax_list_2[1],
scale_list_2[1], amax_list_2[2], scale_list_2[2])
fp8_meta_pkg1 = FP8MetaPackage(
amax_list_1[0],
scale_list_1[0],
amax_list_1[1],
scale_list_1[1],
amax_list_1[2],
scale_list_1[2],
)
fp8_meta_pkg2 = FP8MetaPackage(
amax_list_2[0],
scale_list_2[0],
amax_list_2[1],
scale_list_2[1],
amax_list_2[2],
scale_list_2[2],
)
if multi_gpus:
layernorm_input_axes = LAYERNORM_INPUT_AXES
@@ -111,60 +126,68 @@ class TestDistributedLayernormMLP:
# out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
return jnp.mean(
fused_layernorm_fp8_mlp(x,
fused_layernorm_fp8_mlp(
x,
ln_scale,
None, [kernel_1, kernel_2], [bias_1, bias_2],
None,
[kernel_1, kernel_2],
[bias_1, bias_2],
[fp8_meta_pkg1, fp8_meta_pkg2],
layernorm_type,
layernorm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes,
activation_type=activation_type,
use_bias=use_bias))
use_bias=use_bias,
)
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('input_shape', INPUT_SHAPE)
@pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear')])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('use_bias', [True, False])
def test_layernorm_fp8_mlp_primitive(self, mesh_config, activation_type, use_bias, input_shape,
dtype):
@pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("use_bias", [True, False])
def test_layernorm_fp8_mlp_primitive(
self, mesh_config, activation_type, use_bias, input_shape, dtype
):
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = 'rmsnorm'
layernorm_type = "rmsnorm"
fp8_amax_list_1 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32)
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
]
fp8_amax_list_2 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32)
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
]
fp8_scale_list_1 = [
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32)
jnp.ones((1,), jnp.float32),
]
fp8_scale_list_2 = [
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32)
jnp.ones((1,), jnp.float32),
]
inputs = [x, gamma, k1, k2, b1, b2] = \
self.generate_inputs(input_shape, activation_type, use_bias, dtype)
inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
input_shape, activation_type, use_bias, dtype
)
inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2]
static_inputs = [layernorm_type, activation_type, use_bias]
value_and_grad_func = jax.value_and_grad(self.layernorm_fp8_mlp_prim_func,
argnums=range(len(inputs)))
value_and_grad_func = jax.value_and_grad(
self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
)
# Single GPU
single_jitter = jax.jit(value_and_grad_func,
static_argnums=range(len(inputs),
len(static_inputs) + len(inputs)))
single_jitter = jax.jit(
value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs))
)
with fp8_autocast(enabled=True):
single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
@@ -172,12 +195,12 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec('fsdp', None, 'tp'))
k2_sharding = NamedSharding(mesh, PartitionSpec('tp', 'fsdp'))
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding)
if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec(None, 'tp'))
b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding)
else:
b1_sharding = b1_ = None
@@ -186,17 +209,29 @@ class TestDistributedLayernormMLP:
# Position ref for sharding pspec lists
# x, gamma, k1, k2, b1,
# b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
in_shardings = (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None,
None, None)
out_shardings = (None, (None, None, k1_sharding, k2_sharding, b1_sharding, None, None,
None, None, None))
in_shardings = (
None,
None,
k1_sharding,
k2_sharding,
b1_sharding,
None,
None,
None,
None,
None,
)
out_shardings = (
None,
(None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, None, None),
)
multi_jitter = jax.jit(value_and_grad_func,
multi_jitter = jax.jit(
value_and_grad_func,
in_shardings=in_shardings,
out_shardings=out_shardings,
static_argnums=range(len(multi_inputs),
len(static_inputs) + len(multi_inputs) +
1)) # +1 for multi_gpus
static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
@@ -206,26 +241,28 @@ class TestDistributedLayernormMLP:
if isinstance(multi_grads[i], list):
assert isinstance(single_grads[i], list)
for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
assert_allclose(m_grad,
s_grad,
dtype=dtype,
err_msg=f'multi_grads[{i}] is not close')
assert_allclose(
m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
)
else:
assert_allclose(multi_grads[i],
assert_allclose(
multi_grads[i],
single_grads[i],
dtype=dtype,
err_msg=f'multi_grads[{i}] is not close')
err_msg=f"multi_grads[{i}] is not close",
)
def _test_layernorm_mlp(self, mesh_config, activation_type, use_bias, input_shape, dtype,
use_fp8):
def _test_layernorm_mlp(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8
):
batch, seqlen, hidden_in = input_shape
layernorm_type = 'rmsnorm'
layernorm_type = "rmsnorm"
rng = jax.random.PRNGKey(0)
subkeys = jax.random.split(rng, 2)
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
init_rngs = {'params': subkeys[1]}
init_rngs = {"params": subkeys[1]}
# Single GPUs
with fp8_autocast(enabled=use_fp8):
@@ -238,16 +275,17 @@ class TestDistributedLayernormMLP:
use_bias=use_bias,
)
params_single = ln_mlp_single.init(init_rngs, x)
mlp_out_single, ln_out_single = ln_mlp_single.apply(params_single,
x,
deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply(
params_single, x, deterministic=True
)
# Multi GPUs
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
ln_mlp_sharded = LayerNormMLP(layernorm_type=layernorm_type,
ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
@@ -262,41 +300,37 @@ class TestDistributedLayernormMLP:
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
name='mlp')
name="mlp",
)
params_sharded = ln_mlp_sharded.init(init_rngs, x)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(params_sharded,
x,
deterministic=True)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
params_sharded, x, deterministic=True
)
# Make sure params values are the same
assert_tree_like_allclose(params_sharded['params'], params_single['params'])
assert_tree_like_allclose(params_sharded["params"], params_single["params"])
assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)
@pytest.mark.parametrize('input_shape', INPUT_SHAPE)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('activation_type', [("gelu",), ('silu', 'linear'), ('gelu', 'gelu')])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('use_bias', [True, False])
@pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize("activation_type", [("gelu",), ("silu", "linear"), ("gelu", "gelu")])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("use_bias", [True, False])
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
self._test_layernorm_mlp(mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=False)
self._test_layernorm_mlp(
mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear'), ('gelu', 'gelu')])
@pytest.mark.parametrize('use_bias', [True, False])
@pytest.mark.parametrize('input_shape', INPUT_SHAPE)
@pytest.mark.parametrize('dtype', DTYPES)
def test_layernorm_fp8_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape,
dtype):
self._test_layernorm_mlp(mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=True)
@pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear"), ("gelu", "gelu")])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize("dtype", DTYPES)
def test_layernorm_fp8_mlp_layer(
self, mesh_config, activation_type, use_bias, input_shape, dtype
):
self._test_layernorm_mlp(
mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=True
)
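
The multi-GPU path above jits the value-and-grad function with explicit in_shardings/out_shardings over an ("fsdp", "tp") mesh. A hedged, self-contained sketch of that pattern, assuming two visible JAX devices (the function and shapes are illustrative only):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.asarray(jax.devices()[:2]).reshape(1, 2)
mesh = Mesh(devices, ("fsdp", "tp"))

def forward(x, w):
    return jnp.mean(x @ w)

with mesh:
    x_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None))
    w_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
    jitted = jax.jit(forward, in_shardings=(x_sharding, w_sharding), out_shardings=None)
    x = jax.device_put(jnp.ones((8, 16), jnp.float32), x_sharding)
    w = jax.device_put(jnp.ones((16, 4), jnp.float32), w_sharding)
    print(jitted(x, w))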
@@ -3,4 +3,5 @@
# See LICENSE for license information.
import transformer_engine.jax
print("OK")