Unverified commit 9416519d, authored by Kirthi Shankar Sivamani and committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
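
The hunks below all apply one consistent style: double-quoted strings, roughly 100-character lines, trailing commas, and one-argument-per-line call sites. That matches what black produces with an extended line length. A minimal, hypothetical sketch of reproducing one of the hunks with black's Python API (assuming the black package is installed; the diff itself does not state which formatter or settings were actually used):

import black

# One of the pre-formatting snippets from the MNIST example below.
src = (
    "parser.add_argument(\n"
    '    "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"\n'
    ")\n"
)
# With a 100-character limit the call fits on one line, matching the "+" side of the hunk.
print(black.format_str(src, mode=black.Mode(line_length=100)))
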
@@ -79,6 +79,7 @@ def calibrate(model, device, test_loader, fp8):
             with te.fp8_autocast(enabled=fp8, calibrating=True):
                 output = model(data)

 def test(model, device, test_loader, use_fp8):
     """Testing function."""
     model.eval()
@@ -89,12 +90,8 @@ def test(model, device, test_loader, use_fp8):
             data, target = data.to(device), target.to(device)
             with te.fp8_autocast(enabled=use_fp8):
                 output = model(data)
-                test_loss += F.nll_loss(
-                    output, target, reduction="sum"
-                ).item()  # sum up batch loss
-                pred = output.argmax(
-                    dim=1, keepdim=True
-                )  # get the index of the max log-probability
+                test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
+                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                 correct += pred.eq(target.view_as(pred)).sum().item()

     test_loss /= len(test_loader.dataset)
@@ -150,9 +147,7 @@ def main():
         default=False,
         help="quickly check a single pass",
     )
-    parser.add_argument(
-        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
-    )
+    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
     parser.add_argument(
         "--log-interval",
         type=int,
@@ -167,7 +162,10 @@ def main():
         help="For Saving the current Model",
     )
     parser.add_argument(
-        "--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration"
+        "--use-fp8",
+        action="store_true",
+        default=False,
+        help="Use FP8 for inference and training without recalibration",
     )
     parser.add_argument(
         "--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
@@ -215,7 +213,7 @@ def main():
     if args.save_model or args.use_fp8_infer:
         torch.save(model.state_dict(), "mnist_cnn.pt")
-        print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8_infer))
+        print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer))
         weights = torch.load("mnist_cnn.pt")
         model.load_state_dict(weights)
         test(model, device, test_loader, args.use_fp8_infer)
......
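
The MNIST hunks above revolve around TransformerEngine's calibration flow: run a few batches in higher precision with calibrating=True so FP8 scaling statistics are recorded, then evaluate with FP8 enabled and no recalibration. A condensed sketch of that flow, under the assumption that the model is built from TE modules; model, calib_loader, and sample are placeholders, and the surrounding script is omitted:

import torch
import transformer_engine.pytorch as te

model.eval()
with torch.no_grad():
    # Higher-precision forward passes while amax statistics are collected for FP8 scaling.
    with te.fp8_autocast(enabled=False, calibrating=True):
        for data, _ in calib_loader:
            model(data)

    # Later inference can enable FP8 directly, matching the --use-fp8-infer path above.
    with te.fp8_autocast(enabled=True):
        output = model(sample)
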
@@ -18,21 +18,26 @@ path = sys.argv[1]
 config_path = os.path.dirname(os.path.realpath(__file__)) + "/config.json"

 class bcolors:
-    OKGREEN = '\033[92m'
-    WARNING = '\033[93m'
-    FAIL = '\033[91m'
-    ENDC = '\033[0m'
+    OKGREEN = "\033[92m"
+    WARNING = "\033[93m"
+    FAIL = "\033[91m"
+    ENDC = "\033[0m"

 def print_ok(msg):
     print(f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}")

 def print_fail(msg):
     print(f"{bcolors.FAIL}{msg}{bcolors.ENDC}")

 def print_warn(msg):
     print(f"{bcolors.WARNING}{msg}{bcolors.ENDC}")

 with open(config_path, "r") as f:
     c = json.load(f)
     current_year = datetime.date.today().year
@@ -41,7 +46,7 @@ with open(config_path, "r") as f:
     else:
         year_string = str(c["initial_year"]) + "-" + str(current_year)
     copyright_string = c["copyright"].replace("<YEAR>", year_string)
-    license = c["license"].split('\n')
+    license = c["license"].split("\n")
     excludes = c["exclude"]

 root_path = os.path.abspath(path)
 copyright_only = c["copyright_only"]
@@ -49,36 +54,42 @@ with open(config_path, "r") as f:
 has_gitignore = os.path.exists(root_path + "/.gitignore")

 def strip_star_slash(s):
     ret = s
-    if ret.startswith('*'):
+    if ret.startswith("*"):
         ret = ret[1:]
-    if ret.endswith('/'):
+    if ret.endswith("/"):
         ret = ret[:-1]
     return ret

 if has_gitignore:
     with open(root_path + "/.gitignore", "r") as f:
         for line in f.readlines():
             excludes.append(strip_star_slash(line.strip()))

 def get_file_type(path):
-    ext = {"c": ["c", "cpp", "cu", "h", "cuh"],
-           "py": ["py"],
-           "rst": ["rst"],
-           "txt": ["txt"],
-           "cfg": ["cfg"],
-           "sh": ["sh"],
-           "md": ["md"],
-          }
+    ext = {
+        "c": ["c", "cpp", "cu", "h", "cuh"],
+        "py": ["py"],
+        "rst": ["rst"],
+        "txt": ["txt"],
+        "cfg": ["cfg"],
+        "sh": ["sh"],
+        "md": ["md"],
+    }
     tmp = path.split(".")
     for filetype, ext_list in ext.items():
         if tmp[-1] in ext_list:
             return filetype
     return "unknown"

 success = True

 def check_file(path):
     global success
     N = 10
@@ -127,9 +138,10 @@ def check_file(path):
     if copyright_found and license_found:
         print_ok("OK")

 for root, dirs, files in os.walk(root_path):
     print(f"Entering {root}")
-    hidden = [d for d in dirs if d.startswith('.')] + [f for f in files if f.startswith('.')]
+    hidden = [d for d in dirs if d.startswith(".")] + [f for f in files if f.startswith(".")]
     all_excludes = excludes + hidden
     to_remove = []
     for d in dirs:
......
@@ -27,12 +27,13 @@ current_file_path = Path(__file__).parent.resolve()
 from setuptools.command.build_ext import build_ext as BuildExtension

 if "pytorch" in frameworks:
     from torch.utils.cpp_extension import BuildExtension
 elif "paddle" in frameworks:
     from paddle.utils.cpp_extension import BuildExtension
 elif "jax" in frameworks:
-    install_and_import('pybind11')
+    install_and_import("pybind11")
     from pybind11.setup_helpers import build_ext as BuildExtension
@@ -86,34 +87,45 @@ if __name__ == "__main__":
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
             from build_tools.pytorch import setup_pytorch_extension

             ext_modules.append(
                 setup_pytorch_extension(
                     "transformer_engine/pytorch/csrc",
                     current_file_path / "transformer_engine" / "pytorch" / "csrc",
-                    current_file_path / "transformer_engine"))
+                    current_file_path / "transformer_engine",
+                )
+            )

         if "jax" in frameworks:
             from build_tools.jax import setup_jax_extension

             ext_modules.append(
                 setup_jax_extension(
                     "transformer_engine/jax/csrc",
                     current_file_path / "transformer_engine" / "jax" / "csrc",
-                    current_file_path / "transformer_engine"))
+                    current_file_path / "transformer_engine",
+                )
+            )

         if "paddle" in frameworks:
             from build_tools.paddle import setup_paddle_extension

             ext_modules.append(
                 setup_paddle_extension(
                     "transformer_engine/paddle/csrc",
                     current_file_path / "transformer_engine" / "paddle" / "csrc",
-                    current_file_path / "transformer_engine"))
+                    current_file_path / "transformer_engine",
+                )
+            )

     # Configure package
     setuptools.setup(
         name="transformer_engine",
         version=__version__,
         packages=setuptools.find_packages(
-            include=["transformer_engine",
-                     "transformer_engine.*",
-                     "transformer_engine/build_tools"],
+            include=[
+                "transformer_engine",
+                "transformer_engine.*",
+                "transformer_engine/build_tools",
+            ],
         ),
         extras_require={
             "test": test_requires,
@@ -125,5 +137,5 @@ if __name__ == "__main__":
         install_requires=install_requires,
         license_files=("LICENSE",),
         include_package_data=True,
-        package_data={"": ["VERSION.txt"]}
+        package_data={"": ["VERSION.txt"]},
     )
@@ -132,7 +132,7 @@ void compute_bwd_ref(
   for (int b = 0; b < batches; ++b) {
     for (int h = 0; h < heads; ++h) {
       size_t offset = b * batch_size + h * head_size;
       compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset,
                               buff + offset, scaling_factor, batches, heads, rows, cols);
     }
   }
......
@@ -6,7 +6,7 @@ import jax
 import pytest

-@pytest.fixture(autouse=True, scope='function')
+@pytest.fixture(autouse=True, scope="function")
 def clear_live_arrays():
     """
     Clear all live arrays to keep the resource clean
......
@@ -16,15 +16,15 @@ from utils import assert_allclose, is_devices_enough
 def generate_configs():
     configs = []
     if is_devices_enough(2):
-        configs.append([2, (2,), ('dp'), MeshResource(dp_resource='dp')])
-        configs.append([2, (2,), ('tp'), MeshResource(tp_resource='tp')])
+        configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
+        configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])
     if is_devices_enough(4):
         TP_size = 2
         DP_size = 2
         configs.append(
-            [4, (DP_size, TP_size), ('dp', 'tp'),
-             MeshResource(dp_resource='dp', tp_resource='tp')])
+            [4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
+        )

     return configs
@@ -46,7 +46,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
         bytes_count = 0

         def get_bytes_per_txt(t):
-            '''
+            """
             The pattern of t would be like:
                 'f32[]',
                 '(f32[1024]{0}',
@@ -54,24 +54,24 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
                 'f8E4M3FN[1024]{0}',
                 'i32[1024]{0}',
                 'bf16[1024,1024]{0}'
-            '''
-            match = re.search(r'(i|f)(\d+).*\[([0-9,]*)\]', t)
+            """
+            match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
             _, bits_of_type, shape = match.groups()
             bytes_of_type = int(bits_of_type) // 8
-            if shape == '':
+            if shape == "":
                 num_of_elements = 1
             else:
-                num_of_elements = reduce(operator.mul, map(int, shape.split(',')))
+                num_of_elements = reduce(operator.mul, map(int, shape.split(",")))
             return bytes_of_type * num_of_elements

         # ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...]
-        if '(' in hlo_text[2]:
+        if "(" in hlo_text[2]:
             for txt in hlo_text[2:]:
                 bytes_count += get_bytes_per_txt(txt)
-                if ')' in txt:
+                if ")" in txt:
                     break
         else:  # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
             bytes_count = get_bytes_per_txt(hlo_text[2])
         return bytes_count
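
The get_bytes_per_txt helper above turns an HLO operand string such as 'bf16[1024,1024]{0}' into a byte count for the collective-size bookkeeping. A self-contained sketch of the same parsing logic; the helper name and sample strings here are illustrative, not part of the test suite:

import re
import operator
from functools import reduce

def bytes_per_hlo_operand(t):
    # Same regex as get_bytes_per_txt: dtype letter, bit width, then the bracketed shape.
    match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
    _, bits_of_type, shape = match.groups()
    bytes_of_type = int(bits_of_type) // 8
    num_of_elements = 1 if shape == "" else reduce(operator.mul, map(int, shape.split(",")))
    return bytes_of_type * num_of_elements

print(bytes_per_hlo_operand("f32[]"))                # 4: one FP32 scalar
print(bytes_per_hlo_operand("bf16[1024,1024]{0}"))   # 2097152: 2 bytes * 1024 * 1024
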
@@ -91,21 +91,24 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
         return result

     target_result = count_collectives(target_splitted_hlo)
-    assert target_result == coll_count_ref, \
-        f"Expected collective count is {coll_count_ref}, but got {target_result}."
+    assert (
+        target_result == coll_count_ref
+    ), f"Expected collective count is {coll_count_ref}, but got {target_result}."

-def compare_ops(target_func,
-                ref_func,
-                inputs,
-                coll_count_ref,
-                *,
-                grad_args=None,
-                metric_fwd_dtype=None,
-                metric_bwd_dtype=None,
-                in_shardings=_UNSPECIFIED,
-                out_shardings=_UNSPECIFIED,
-                **kwargs):
+def compare_ops(
+    target_func,
+    ref_func,
+    inputs,
+    coll_count_ref,
+    *,
+    grad_args=None,
+    metric_fwd_dtype=None,
+    metric_bwd_dtype=None,
+    in_shardings=_UNSPECIFIED,
+    out_shardings=_UNSPECIFIED,
+    **kwargs,
+):
     assert len(inputs) >= 1
     if metric_fwd_dtype is None:
......
@@ -9,16 +9,8 @@ import jax.numpy as jnp
 import numpy as np
 from flax.linen import dot_product_attention
 from jax import random
-from jax.sharding import (
-    Mesh,
-    NamedSharding,
-    PartitionSpec
-)
-from distributed_test_base import (
-    generate_configs,
-    generate_collectives_count,
-    compare_ops
-)
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
 from utils import make_causal_mask, make_self_mask
 from transformer_engine.jax import fp8_autocast
 from transformer_engine.jax.attention import (
@@ -27,7 +19,7 @@ from transformer_engine.jax.attention import (
     fused_attn_kvpacked,
     AttnBiasType,
     AttnMaskType,
-    QKVLayout
+    QKVLayout,
 )
@@ -36,8 +28,9 @@ DTYPES = [jnp.float16, jnp.bfloat16]
 class TestDistributedSelfAttn:

-    def generate_collectives_count_ref(self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape,
-                                       dtype):
+    def generate_collectives_count_ref(
+        self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
+    ):
         jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
         _, seqlen, _, heads, _ = shape
         is_dp_enabled = mesh_resource.dp_resource is not None
@@ -46,7 +39,7 @@ class TestDistributedSelfAttn:
         idx = mesh_axes.index(mesh_resource.tp_resource)
         tp_size = mesh_shape[idx]

         all_reduce_loss_bytes = 4  # 1 * FP32
         bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
         allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
         # for loss and dbias
@@ -57,8 +50,11 @@ class TestDistributedSelfAttn:
         qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
-        bias = random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) \
-            if with_bias else None
+        bias = (
+            random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)
+            if with_bias
+            else None
+        )

         mask = None
         if attn_mask_type == AttnMaskType.PADDING_MASK:
@@ -66,47 +62,76 @@ class TestDistributedSelfAttn:
         elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
             mask = make_self_mask(batch, seqlen)

-        qkv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
-                                  None)
-        bias_pspec = PartitionSpec(None, mesh_resource.tp_resource, None, None) \
-            if with_bias else None
-        mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
-            if attn_mask_type != AttnMaskType.NO_MASK else None
+        qkv_pspec = PartitionSpec(
+            mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
+        )
+        bias_pspec = (
+            PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None
+        )
+        mask_pspec = (
+            PartitionSpec(mesh_resource.dp_resource, None, None, None)
+            if attn_mask_type != AttnMaskType.NO_MASK
+            else None
+        )

         return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)

-    @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
-    @pytest.mark.parametrize('data_shape', [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
+    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+    @pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
     @pytest.mark.parametrize(
-        'attn_bias_type',
-        [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS])
-    @pytest.mark.parametrize('attn_mask_type',
-                             [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
-    @pytest.mark.parametrize('dtype', DTYPES)
-    def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
-                       attn_bias_type, attn_mask_type, dtype):
+        "attn_bias_type",
+        [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
+    )
+    @pytest.mark.parametrize(
+        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
+    )
+    @pytest.mark.parametrize("dtype", DTYPES)
+    def test_self_attn(
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        data_shape,
+        attn_bias_type,
+        attn_mask_type,
+        dtype,
+    ):
         dropout_prob = 0.0
         is_training = True
         scaling_factor = 1.0

         _, seqlen, _, num_head, hidden = data_shape
-        if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
-                                              attn_mask_type, dropout_prob, num_head, num_head,
-                                              seqlen, seqlen, hidden):
+        if not is_fused_attn_kernel_available(
+            dtype,
+            dtype,
+            QKVLayout.BS3HD,
+            attn_bias_type,
+            attn_mask_type,
+            dropout_prob,
+            num_head,
+            num_head,
+            seqlen,
+            seqlen,
+            hidden,
+        ):
             pytest.skip(f"No FusedAttn backwend found")

         def target_func(qkv, bias, mask):
             return jnp.mean(
-                fused_attn_qkvpacked(qkv,
-                                     bias,
-                                     mask,
-                                     None,
-                                     attn_bias_type=attn_bias_type,
-                                     attn_mask_type=attn_mask_type,
-                                     scaling_factor=scaling_factor,
-                                     dropout_probability=dropout_prob,
-                                     is_training=is_training))
+                fused_attn_qkvpacked(
+                    qkv,
+                    bias,
+                    mask,
+                    None,
+                    attn_bias_type=attn_bias_type,
+                    attn_mask_type=attn_mask_type,
+                    scaling_factor=scaling_factor,
+                    dropout_probability=dropout_prob,
+                    is_training=is_training,
+                )
+            )

         def ref_func(qkv, bias, mask):
             query, key, value = jnp.split(qkv, [1, 2], axis=-3)
@@ -114,52 +139,59 @@ class TestDistributedSelfAttn:
             key = jnp.squeeze(key)
             value = jnp.squeeze(value)

-            output = dot_product_attention(query,
-                                           key,
-                                           value,
-                                           bias=bias,
-                                           mask=mask,
-                                           deterministic=is_training,
-                                           dropout_rate=dropout_prob,
-                                           dropout_rng=None,
-                                           dtype=jnp.float32)
+            output = dot_product_attention(
+                query,
+                key,
+                value,
+                bias=bias,
+                mask=mask,
+                deterministic=is_training,
+                dropout_rate=dropout_prob,
+                dropout_rng=None,
+                dtype=jnp.float32,
+            )
             return jnp.mean(output).astype(dtype)

         with_bias = attn_bias_type != AttnBiasType.NO_BIAS
-        (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = \
-            self.generate_inputs(data_shape, mesh_resource, with_bias,
-                                 attn_mask_type, dtype)
-        collective_count_ref = self.generate_collectives_count_ref(mesh_shape, mesh_axes,
-                                                                   mesh_resource, with_bias,
-                                                                   data_shape, dtype)
+        (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs(
+            data_shape, mesh_resource, with_bias, attn_mask_type, dtype
+        )
+        collective_count_ref = self.generate_collectives_count_ref(
+            mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype
+        )
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, fp8_autocast(mesh_resource=mesh_resource):
             qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
-            bias_ = jax.device_put(bias, NamedSharding(mesh, bias_pspec)) \
-                if bias is not None else bias
-            mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
-                if mask is not None else mask
+            bias_ = (
+                jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias
+            )
+            mask_ = (
+                jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
+            )

             grad_args = (0, 1) if with_bias else (0,)
             out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)

-            compare_ops(target_func,
-                        ref_func, [qkv_, bias_, mask_],
-                        collective_count_ref,
-                        grad_args=grad_args,
-                        metric_fwd_dtype=dtype,
-                        metric_bwd_dtype=dtype,
-                        in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
-                        out_shardings=(None, out_grad_shardings))
+            compare_ops(
+                target_func,
+                ref_func,
+                [qkv_, bias_, mask_],
+                collective_count_ref,
+                grad_args=grad_args,
+                metric_fwd_dtype=dtype,
+                metric_bwd_dtype=dtype,
+                in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
+                out_shardings=(None, out_grad_shardings),
+            )


 class TestDistributedCrossAttn:

     def generate_collectives_count_ref(self):
         # for loss
         all_reduce_loss_bytes = 4  # 1 * FP32
         return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

     def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype):
@@ -176,20 +208,26 @@ class TestDistributedCrossAttn:
         q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)

-        kv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
-                                 None)
-        mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
-            if attn_mask_type != AttnMaskType.NO_MASK else None
+        kv_pspec = PartitionSpec(
+            mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
+        )
+        mask_pspec = (
+            PartitionSpec(mesh_resource.dp_resource, None, None, None)
+            if attn_mask_type != AttnMaskType.NO_MASK
+            else None
+        )

         return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)

-    @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
-    @pytest.mark.parametrize('data_shape', [[32, 128, 12, 64], [32, 512, 16, 64]])
-    @pytest.mark.parametrize('attn_mask_type',
-                             [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
-    @pytest.mark.parametrize('dtype', DTYPES)
-    def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
-                        attn_mask_type, dtype):
+    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+    @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
+    @pytest.mark.parametrize(
+        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
+    )
+    @pytest.mark.parametrize("dtype", DTYPES)
+    def test_cross_attn(
+        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
+    ):
         attn_bias_type = AttnBiasType.NO_BIAS
         dropout_prob = 0.0
         is_training = True
@@ -197,23 +235,36 @@ class TestDistributedCrossAttn:
         _, seqlen, num_head, hidden = data_shape

-        if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type,
-                                              attn_mask_type, dropout_prob, num_head, num_head,
-                                              seqlen, seqlen, hidden):
+        if not is_fused_attn_kernel_available(
+            dtype,
+            dtype,
+            QKVLayout.BSHD_BS2HD,
+            attn_bias_type,
+            attn_mask_type,
+            dropout_prob,
+            num_head,
+            num_head,
+            seqlen,
+            seqlen,
+            hidden,
+        ):
             pytest.skip(f"No FusedAttn backwend found")

         def target_func(q, kv, mask):
             return jnp.mean(
-                fused_attn_kvpacked(q,
-                                    kv,
-                                    None,
-                                    mask,
-                                    None,
-                                    attn_bias_type=attn_bias_type,
-                                    attn_mask_type=attn_mask_type,
-                                    scaling_factor=scaling_factor,
-                                    dropout_probability=dropout_prob,
-                                    is_training=is_training))
+                fused_attn_kvpacked(
+                    q,
+                    kv,
+                    None,
+                    mask,
+                    None,
+                    attn_bias_type=attn_bias_type,
+                    attn_mask_type=attn_mask_type,
+                    scaling_factor=scaling_factor,
+                    dropout_probability=dropout_prob,
+                    is_training=is_training,
+                )
+            )

         def ref_func(query, kv, mask):
             key, value = jnp.split(kv, [1], axis=-3)
@@ -221,34 +272,41 @@ class TestDistributedCrossAttn:
             key = jnp.squeeze(key)
             value = jnp.squeeze(value)

-            output = dot_product_attention(query,
-                                           key,
-                                           value,
-                                           bias=None,
-                                           mask=mask,
-                                           deterministic=is_training,
-                                           dropout_rate=dropout_prob,
-                                           dropout_rng=None,
-                                           dtype=jnp.float32)
+            output = dot_product_attention(
+                query,
+                key,
+                value,
+                bias=None,
+                mask=mask,
+                deterministic=is_training,
+                dropout_rate=dropout_prob,
+                dropout_rng=None,
+                dtype=jnp.float32,
+            )
             return jnp.mean(output).astype(dtype)

-        (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = \
-            self.generate_inputs(data_shape, mesh_resource, attn_mask_type, dtype)
+        (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
+            data_shape, mesh_resource, attn_mask_type, dtype
+        )
         collective_count_ref = self.generate_collectives_count_ref()
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, fp8_autocast(mesh_resource=mesh_resource):
             q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
             kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
-            mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
-                if mask is not None else mask
+            mask_ = (
+                jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
+            )

-            compare_ops(target_func,
-                        ref_func, [q_, kv_, mask_],
-                        collective_count_ref,
-                        grad_args=(0, 1),
-                        metric_fwd_dtype=dtype,
-                        metric_bwd_dtype=dtype,
-                        in_shardings=(q_pspec, kv_pspec, mask_pspec),
-                        out_shardings=(None, (q_pspec, kv_pspec)))
+            compare_ops(
+                target_func,
+                ref_func,
+                [q_, kv_, mask_],
+                collective_count_ref,
+                grad_args=(0, 1),
+                metric_fwd_dtype=dtype,
+                metric_bwd_dtype=dtype,
+                in_shardings=(q_pspec, kv_pspec, mask_pspec),
+                out_shardings=(None, (q_pspec, kv_pspec)),
+            )
@@ -35,31 +35,44 @@ class TestDistributedLayernorm:
         else:
             raise NotImplementedError

-        g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
+        g_pspec = b_pspec = (
+            PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
+        )

         return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)

     def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
         jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
         is_dp_enabled = mesh_resource.dp_resource is not None
-        assert ln_type in ['layernorm', 'rmsnorm']
+        assert ln_type in ["layernorm", "rmsnorm"]
         all_reduce_loss_bytes = 4  # 1 * FP32
         # for loss, dgamma and dbeta
-        weight_count = 2 if ln_type == 'layernorm' else 1
-        allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
-        return generate_collectives_count(allreduce=allreduce_total_bytes * int(is_dp_enabled),
-                                          allgather=0,
-                                          other=0)
+        weight_count = 2 if ln_type == "layernorm" else 1
+        allreduce_total_bytes = (
+            all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
+        )
+        return generate_collectives_count(
+            allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0
+        )

-    @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
-    @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
-    @pytest.mark.parametrize('dtype', DTYPES)
-    @pytest.mark.parametrize('zero_centered_gamma', [False, True])
-    @pytest.mark.parametrize('shard_weights', [False, True])
-    def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype,
-                       zero_centered_gamma, shard_weights):
+    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+    @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
+    @pytest.mark.parametrize("dtype", DTYPES)
+    @pytest.mark.parametrize("zero_centered_gamma", [False, True])
+    @pytest.mark.parametrize("shard_weights", [False, True])
+    def test_layernorm(
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        data_shape,
+        dtype,
+        zero_centered_gamma,
+        shard_weights,
+    ):
         epsilon = 1e-6
-        ln_type = 'layernorm'
+        ln_type = "layernorm"

         def target_func(x, gamma, beta):
             return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
@@ -75,10 +88,12 @@ class TestDistributedLayernorm:
             output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
             return jnp.mean(output)

-        (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \
-            self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
-        collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
-                                                                   data_shape, dtype)
+        (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
+            data_shape, mesh_resource, dtype, shard_weights
+        )
+        collective_count_ref = self.generate_collectives_count_ref(
+            mesh_resource, ln_type, data_shape, dtype
+        )
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, fp8_autocast(mesh_resource=mesh_resource):
@@ -88,20 +103,25 @@ class TestDistributedLayernorm:
             with warnings.catch_warnings(record=True) as warns:
                 try:
-                    compare_ops(target_func,
-                                ref_func, [x_, gamma_, beta_],
-                                collective_count_ref,
-                                grad_args=(0, 1, 2),
-                                metric_fwd_dtype=dtype,
-                                metric_bwd_dtype=dtype,
-                                in_shardings=(x_pspec, g_pspec, b_pspec),
-                                out_shardings=(None, (x_pspec, g_pspec, b_pspec)))
+                    compare_ops(
+                        target_func,
+                        ref_func,
+                        [x_, gamma_, beta_],
+                        collective_count_ref,
+                        grad_args=(0, 1, 2),
+                        metric_fwd_dtype=dtype,
+                        metric_bwd_dtype=dtype,
+                        in_shardings=(x_pspec, g_pspec, b_pspec),
+                        out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
+                    )
                 except AssertionError as err:
                     # Layernorm should still produce the correct numerical result with
                     # gamma/beta sharded. However, the collective count may not be the same
                     # when XLA is forced to unshard gamma and/or beta. We can catch
                     # and ignore that specific error here.
-                    if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err):
+                    if (
+                        g_pspec[-1] is None and b_pspec[-1] is None
+                    ) or "Expected collective count" not in str(err):
                         raise err
                 finally:
                     for w in warns:
@@ -110,13 +130,15 @@ class TestDistributedLayernorm:
                             "unsupported sharding of gamma and/or beta"
                         )

-    @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
-    @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
-    @pytest.mark.parametrize('dtype', DTYPES)
-    @pytest.mark.parametrize('shard_weights', [False, True])
-    def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights):
+    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+    @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
+    @pytest.mark.parametrize("dtype", DTYPES)
+    @pytest.mark.parametrize("shard_weights", [False, True])
+    def test_rmsnorm(
+        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights
+    ):
         epsilon = 1e-6
-        ln_type = 'rmsnorm'
+        ln_type = "rmsnorm"

         def target_func(x, gamma):
             return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
@@ -128,10 +150,12 @@ class TestDistributedLayernorm:
             output = y * gamma
             return jnp.mean(output)

-        (x, gamma, _), (x_pspec, g_pspec, _) = \
-            self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
-        collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
-                                                                   data_shape, dtype)
+        (x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
+            data_shape, mesh_resource, dtype, shard_weights
+        )
+        collective_count_ref = self.generate_collectives_count_ref(
+            mesh_resource, ln_type, data_shape, dtype
+        )
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, fp8_autocast(mesh_resource=mesh_resource):
@@ -140,14 +164,17 @@ class TestDistributedLayernorm:
             with warnings.catch_warnings(record=True) as warns:
                 try:
-                    compare_ops(target_func,
-                                ref_func, [x_, gamma_],
-                                collective_count_ref,
-                                grad_args=(0, 1),
-                                metric_fwd_dtype=dtype,
-                                metric_bwd_dtype=dtype,
-                                in_shardings=(x_pspec, g_pspec),
-                                out_shardings=(None, (x_pspec, g_pspec)))
+                    compare_ops(
+                        target_func,
+                        ref_func,
+                        [x_, gamma_],
+                        collective_count_ref,
+                        grad_args=(0, 1),
+                        metric_fwd_dtype=dtype,
+                        metric_bwd_dtype=dtype,
+                        in_shardings=(x_pspec, g_pspec),
+                        out_shardings=(None, (x_pspec, g_pspec)),
+                    )
                 except AssertionError as err:
                     # RmsNorm should still produce the correct numerical result with
                     # gamma/beta sharded. However, the collective count may not be the same
......
@@ -3,4 +3,5 @@
 # See LICENSE for license information.

 import transformer_engine.jax

 print("OK")