Unverified Commit 9416519d authored by Kirthi Shankar Sivamani, committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
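For context, a pass like this is usually produced by running an opinionated formatter over the whole repository. A minimal sketch of reproducing the kind of reflow seen in the hunks below, assuming black with a 100-character line limit (the formatter and its options are assumptions; the commit message does not name the tool):

# Hedged sketch: assumes the repository is formatted with black at line-length 100.
# black.format_str and black.Mode are the formatter's public Python API.
import black

src = (
    "parser.add_argument(\n"
    '    "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"\n'
    ")\n"
)
print(black.format_str(src, mode=black.Mode(line_length=100)))
# The call fits within 100 characters, so black collapses it onto one line:
# parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")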
......@@ -79,6 +79,7 @@ def calibrate(model, device, test_loader, fp8):
with te.fp8_autocast(enabled=fp8, calibrating=True):
output = model(data)
def test(model, device, test_loader, use_fp8):
"""Testing function."""
model.eval()
......@@ -89,12 +90,8 @@ def test(model, device, test_loader, use_fp8):
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8):
output = model(data)
test_loss += F.nll_loss(
output, target, reduction="sum"
).item() # sum up batch loss
pred = output.argmax(
dim=1, keepdim=True
) # get the index of the max log-probability
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
......@@ -150,9 +147,7 @@ def main():
default=False,
help="quickly check a single pass",
)
parser.add_argument(
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
......@@ -167,7 +162,10 @@ def main():
help="For Saving the current Model",
)
parser.add_argument(
"--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration"
"--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
......@@ -215,7 +213,7 @@ def main():
if args.save_model or args.use_fp8_infer:
torch.save(model.state_dict(), "mnist_cnn.pt")
print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer))
print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer))
weights = torch.load("mnist_cnn.pt")
model.load_state_dict(weights)
test(model, device, test_loader, args.use_fp8_infer)
......@@ -18,21 +18,26 @@ path = sys.argv[1]
config_path = os.path.dirname(os.path.realpath(__file__)) + "/config.json"
class bcolors:
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
OKGREEN = "\033[92m"
WARNING = "\033[93m"
FAIL = "\033[91m"
ENDC = "\033[0m"
def print_ok(msg):
print(f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}")
def print_fail(msg):
print(f"{bcolors.FAIL}{msg}{bcolors.ENDC}")
def print_warn(msg):
print(f"{bcolors.WARNING}{msg}{bcolors.ENDC}")
with open(config_path, "r") as f:
c = json.load(f)
current_year = datetime.date.today().year
......@@ -41,7 +46,7 @@ with open(config_path, "r") as f:
else:
year_string = str(c["initial_year"]) + "-" + str(current_year)
copyright_string = c["copyright"].replace("<YEAR>", year_string)
license = c["license"].split('\n')
license = c["license"].split("\n")
excludes = c["exclude"]
root_path = os.path.abspath(path)
copyright_only = c["copyright_only"]
......@@ -49,36 +54,42 @@ with open(config_path, "r") as f:
has_gitignore = os.path.exists(root_path + "/.gitignore")
def strip_star_slash(s):
ret = s
if ret.startswith('*'):
if ret.startswith("*"):
ret = ret[1:]
if ret.endswith('/'):
if ret.endswith("/"):
ret = ret[:-1]
return ret
if has_gitignore:
with open(root_path + "/.gitignore", "r") as f:
for line in f.readlines():
excludes.append(strip_star_slash(line.strip()))
def get_file_type(path):
ext = {"c": ["c", "cpp", "cu", "h", "cuh"],
"py": ["py"],
"rst": ["rst"],
"txt": ["txt"],
"cfg": ["cfg"],
"sh": ["sh"],
"md": ["md"],
}
ext = {
"c": ["c", "cpp", "cu", "h", "cuh"],
"py": ["py"],
"rst": ["rst"],
"txt": ["txt"],
"cfg": ["cfg"],
"sh": ["sh"],
"md": ["md"],
}
tmp = path.split(".")
for filetype, ext_list in ext.items():
if tmp[-1] in ext_list:
return filetype
return "unknown"
success = True
def check_file(path):
global success
N = 10
......@@ -127,9 +138,10 @@ def check_file(path):
if copyright_found and license_found:
print_ok("OK")
for root, dirs, files in os.walk(root_path):
print(f"Entering {root}")
hidden = [d for d in dirs if d.startswith('.')] + [f for f in files if f.startswith('.')]
hidden = [d for d in dirs if d.startswith(".")] + [f for f in files if f.startswith(".")]
all_excludes = excludes + hidden
to_remove = []
for d in dirs:
......@@ -27,12 +27,13 @@ current_file_path = Path(__file__).parent.resolve()
from setuptools.command.build_ext import build_ext as BuildExtension
if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import('pybind11')
install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension
......@@ -86,34 +87,45 @@ if __name__ == "__main__":
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension
ext_modules.append(
setup_pytorch_extension(
"transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine"))
current_file_path / "transformer_engine",
)
)
if "jax" in frameworks:
from build_tools.jax import setup_jax_extension
ext_modules.append(
setup_jax_extension(
"transformer_engine/jax/csrc",
current_file_path / "transformer_engine" / "jax" / "csrc",
current_file_path / "transformer_engine"))
current_file_path / "transformer_engine",
)
)
if "paddle" in frameworks:
from build_tools.paddle import setup_paddle_extension
ext_modules.append(
setup_paddle_extension(
"transformer_engine/paddle/csrc",
current_file_path / "transformer_engine" / "paddle" / "csrc",
current_file_path / "transformer_engine"))
current_file_path / "transformer_engine",
)
)
# Configure package
setuptools.setup(
name="transformer_engine",
version=__version__,
packages=setuptools.find_packages(
include=["transformer_engine",
"transformer_engine.*",
"transformer_engine/build_tools"],
include=[
"transformer_engine",
"transformer_engine.*",
"transformer_engine/build_tools",
],
),
extras_require={
"test": test_requires,
......@@ -125,5 +137,5 @@ if __name__ == "__main__":
install_requires=install_requires,
license_files=("LICENSE",),
include_package_data=True,
package_data={"": ["VERSION.txt"]}
package_data={"": ["VERSION.txt"]},
)
......@@ -132,7 +132,7 @@ void compute_bwd_ref(
for (int b = 0; b < batches; ++b) {
for (int h = 0; h < heads; ++h) {
size_t offset = b * batch_size + h * head_size;
compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset,
compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset,
buff + offset, scaling_factor, batches, heads, rows, cols);
}
}
......@@ -6,7 +6,7 @@ import jax
import pytest
@pytest.fixture(autouse=True, scope='function')
@pytest.fixture(autouse=True, scope="function")
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
......@@ -16,15 +16,15 @@ from utils import assert_allclose, is_devices_enough
def generate_configs():
configs = []
if is_devices_enough(2):
configs.append([2, (2,), ('dp'), MeshResource(dp_resource='dp')])
configs.append([2, (2,), ('tp'), MeshResource(tp_resource='tp')])
configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])
if is_devices_enough(4):
TP_size = 2
DP_size = 2
configs.append(
[4, (DP_size, TP_size), ('dp', 'tp'),
MeshResource(dp_resource='dp', tp_resource='tp')])
[4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
)
return configs
......@@ -46,7 +46,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
bytes_count = 0
def get_bytes_per_txt(t):
'''
"""
The pattern of t would be like:
'f32[]',
'(f32[1024]{0}',
......@@ -54,24 +54,24 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
'f8E4M3FN[1024]{0}',
'i32[1024]{0}',
'bf16[1024,1024]{0}'
'''
match = re.search(r'(i|f)(\d+).*\[([0-9,]*)\]', t)
"""
match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
_, bits_of_type, shape = match.groups()
bytes_of_type = int(bits_of_type) // 8
if shape == '':
if shape == "":
num_of_elements = 1
else:
num_of_elements = reduce(operator.mul, map(int, shape.split(',')))
num_of_elements = reduce(operator.mul, map(int, shape.split(",")))
return bytes_of_type * num_of_elements
# ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...]
if '(' in hlo_text[2]:
if "(" in hlo_text[2]:
for txt in hlo_text[2:]:
bytes_count += get_bytes_per_txt(txt)
if ')' in txt:
if ")" in txt:
break
else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
bytes_count = get_bytes_per_txt(hlo_text[2])
return bytes_count
......@@ -91,21 +91,24 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
return result
target_result = count_collectives(target_splitted_hlo)
assert target_result == coll_count_ref, \
f"Expected collective count is {coll_count_ref}, but got {target_result}."
def compare_ops(target_func,
ref_func,
inputs,
coll_count_ref,
*,
grad_args=None,
metric_fwd_dtype=None,
metric_bwd_dtype=None,
in_shardings=_UNSPECIFIED,
out_shardings=_UNSPECIFIED,
**kwargs):
assert (
target_result == coll_count_ref
), f"Expected collective count is {coll_count_ref}, but got {target_result}."
def compare_ops(
target_func,
ref_func,
inputs,
coll_count_ref,
*,
grad_args=None,
metric_fwd_dtype=None,
metric_bwd_dtype=None,
in_shardings=_UNSPECIFIED,
out_shardings=_UNSPECIFIED,
**kwargs,
):
assert len(inputs) >= 1
if metric_fwd_dtype is None:
......@@ -9,16 +9,8 @@ import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import (
Mesh,
NamedSharding,
PartitionSpec
)
from distributed_test_base import (
generate_configs,
generate_collectives_count,
compare_ops
)
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
......@@ -27,7 +19,7 @@ from transformer_engine.jax.attention import (
fused_attn_kvpacked,
AttnBiasType,
AttnMaskType,
QKVLayout
QKVLayout,
)
......@@ -36,8 +28,9 @@ DTYPES = [jnp.float16, jnp.bfloat16]
class TestDistributedSelfAttn:
def generate_collectives_count_ref(self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape,
dtype):
def generate_collectives_count_ref(
self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
_, seqlen, _, heads, _ = shape
is_dp_enabled = mesh_resource.dp_resource is not None
......@@ -46,7 +39,7 @@ class TestDistributedSelfAttn:
idx = mesh_axes.index(mesh_resource.tp_resource)
tp_size = mesh_shape[idx]
all_reduce_loss_bytes = 4 # 1 * FP32
all_reduce_loss_bytes = 4 # 1 * FP32
bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
# for loss and dbias
......@@ -57,8 +50,11 @@ class TestDistributedSelfAttn:
qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
bias = random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) \
if with_bias else None
bias = (
random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)
if with_bias
else None
)
mask = None
if attn_mask_type == AttnMaskType.PADDING_MASK:
......@@ -66,47 +62,76 @@ class TestDistributedSelfAttn:
elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_self_mask(batch, seqlen)
qkv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
None)
bias_pspec = PartitionSpec(None, mesh_resource.tp_resource, None, None) \
if with_bias else None
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
if attn_mask_type != AttnMaskType.NO_MASK else None
qkv_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
)
bias_pspec = (
PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None
)
mask_pspec = (
PartitionSpec(mesh_resource.dp_resource, None, None, None)
if attn_mask_type != AttnMaskType.NO_MASK
else None
)
return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
@pytest.mark.parametrize(
'attn_bias_type',
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dtype', DTYPES)
def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
attn_bias_type, attn_mask_type, dtype):
"attn_bias_type",
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
)
@pytest.mark.parametrize(
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_self_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
attn_mask_type,
dtype,
):
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
_, seqlen, _, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
attn_mask_type, dropout_prob, num_head, num_head,
seqlen, seqlen, hidden):
if not is_fused_attn_kernel_available(
dtype,
dtype,
QKVLayout.BS3HD,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_head,
seqlen,
seqlen,
hidden,
):
pytest.skip(f"No FusedAttn backwend found")
def target_func(qkv, bias, mask):
return jnp.mean(
fused_attn_qkvpacked(qkv,
bias,
mask,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training))
fused_attn_qkvpacked(
qkv,
bias,
mask,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
)
)
def ref_func(qkv, bias, mask):
query, key, value = jnp.split(qkv, [1, 2], axis=-3)
......@@ -114,52 +139,59 @@ class TestDistributedSelfAttn:
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
key,
value,
bias=bias,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32)
output = dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)
with_bias = attn_bias_type != AttnBiasType.NO_BIAS
(qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, with_bias,
attn_mask_type, dtype)
collective_count_ref = self.generate_collectives_count_ref(mesh_shape, mesh_axes,
mesh_resource, with_bias,
data_shape, dtype)
(qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, with_bias, attn_mask_type, dtype
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
bias_ = jax.device_put(bias, NamedSharding(mesh, bias_pspec)) \
if bias is not None else bias
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
if mask is not None else mask
bias_ = (
jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias
)
mask_ = (
jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
)
grad_args = (0, 1) if with_bias else (0,)
out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)
compare_ops(target_func,
ref_func, [qkv_, bias_, mask_],
collective_count_ref,
grad_args=grad_args,
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
out_shardings=(None, out_grad_shardings))
compare_ops(
target_func,
ref_func,
[qkv_, bias_, mask_],
collective_count_ref,
grad_args=grad_args,
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
out_shardings=(None, out_grad_shardings),
)
class TestDistributedCrossAttn:
def generate_collectives_count_ref(self):
# for loss
all_reduce_loss_bytes = 4 # 1 * FP32
all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype):
......@@ -176,20 +208,26 @@ class TestDistributedCrossAttn:
q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)
kv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
None)
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
if attn_mask_type != AttnMaskType.NO_MASK else None
kv_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
)
mask_pspec = (
PartitionSpec(mesh_resource.dp_resource, None, None, None)
if attn_mask_type != AttnMaskType.NO_MASK
else None
)
return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 12, 64], [32, 512, 16, 64]])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dtype', DTYPES)
def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
attn_mask_type, dtype):
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
@pytest.mark.parametrize(
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_cross_attn(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
is_training = True
......@@ -197,23 +235,36 @@ class TestDistributedCrossAttn:
_, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type,
attn_mask_type, dropout_prob, num_head, num_head,
seqlen, seqlen, hidden):
if not is_fused_attn_kernel_available(
dtype,
dtype,
QKVLayout.BSHD_BS2HD,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_head,
seqlen,
seqlen,
hidden,
):
pytest.skip(f"No FusedAttn backwend found")
def target_func(q, kv, mask):
return jnp.mean(
fused_attn_kvpacked(q,
kv,
None,
mask,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training))
fused_attn_kvpacked(
q,
kv,
None,
mask,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
)
)
def ref_func(query, kv, mask):
key, value = jnp.split(kv, [1], axis=-3)
......@@ -221,34 +272,41 @@ class TestDistributedCrossAttn:
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
key,
value,
bias=None,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32)
output = dot_product_attention(
query,
key,
value,
bias=None,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, attn_mask_type, dtype)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, attn_mask_type, dtype
)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
if mask is not None else mask
compare_ops(target_func,
ref_func, [q_, kv_, mask_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec)))
mask_ = (
jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
)
compare_ops(
target_func,
ref_func,
[q_, kv_, mask_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec)),
)
......@@ -35,31 +35,44 @@ class TestDistributedLayernorm:
else:
raise NotImplementedError
g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
g_pspec = b_pspec = (
PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
)
return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
is_dp_enabled = mesh_resource.dp_resource is not None
assert ln_type in ['layernorm', 'rmsnorm']
all_reduce_loss_bytes = 4 # 1 * FP32
assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta
weight_count = 2 if ln_type == 'layernorm' else 1
allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
return generate_collectives_count(allreduce=allreduce_total_bytes * int(is_dp_enabled),
allgather=0,
other=0)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('zero_centered_gamma', [False, True])
@pytest.mark.parametrize('shard_weights', [False, True])
def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype,
zero_centered_gamma, shard_weights):
weight_count = 2 if ln_type == "layernorm" else 1
allreduce_total_bytes = (
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("shard_weights", [False, True])
def test_layernorm(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
dtype,
zero_centered_gamma,
shard_weights,
):
epsilon = 1e-6
ln_type = 'layernorm'
ln_type = "layernorm"
def target_func(x, gamma, beta):
return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
......@@ -75,10 +88,12 @@ class TestDistributedLayernorm:
output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
return jnp.mean(output)
(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype)
(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
data_shape, mesh_resource, dtype, shard_weights
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
......@@ -88,20 +103,25 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)))
compare_ops(
target_func,
ref_func,
[x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
)
except AssertionError as err:
# Layernorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma and/or beta. We can catch
# and ignore that specific error here.
if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err):
if (
g_pspec[-1] is None and b_pspec[-1] is None
) or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
......@@ -110,13 +130,15 @@ class TestDistributedLayernorm:
"unsupported sharding of gamma and/or beta"
)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('shard_weights', [False, True])
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights):
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shard_weights", [False, True])
def test_rmsnorm(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights
):
epsilon = 1e-6
ln_type = 'rmsnorm'
ln_type = "rmsnorm"
def target_func(x, gamma):
return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
......@@ -128,10 +150,12 @@ class TestDistributedLayernorm:
output = y * gamma
return jnp.mean(output)
(x, gamma, _), (x_pspec, g_pspec, _) = \
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype)
(x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
data_shape, mesh_resource, dtype, shard_weights
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
......@@ -140,14 +164,17 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)))
compare_ops(
target_func,
ref_func,
[x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)),
)
except AssertionError as err:
# RmsNorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
......@@ -3,4 +3,5 @@
# See LICENSE for license information.
import transformer_engine.jax
print("OK")