Unverified commit 85a91997, authored by Kirthi Shankar Sivamani and committed by GitHub

Generalize quantization APIs for FP8/FP4/.. recipes (#2256)



* Initial API change
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change all imports and api
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix typo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix recipe tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix more tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix docs, tests, and make Jax change as well
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change internal uses of fp8_autocast
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address nits
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rename file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* CG function, and small test fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change instances of make_graphed_callables internally
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix distributed tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix test and add more docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup test imports and minimize internal file imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Make is_bf16_available public
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better docs and better api
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* fix nvfp4 test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ca6fedcf
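
In short, this PR makes the quantization entry points recipe-agnostic: `te.fp8_autocast(fp8_recipe=..., fp8_group=...)` becomes `te.autocast(recipe=..., amax_reduction_group=...)`, `fp8_model_init` becomes `quantized_model_init`, and the JAX package mirrors the same renames, as the hunks below show. A minimal PyTorch sketch of the new spelling, assuming a single `te.Linear` layer and a `DelayedScaling` recipe (the sizes and recipe settings are illustrative, not part of this PR):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Any supported recipe object can be passed; MXFP8BlockScaling or
# NVFP4BlockScaling would be used the same way on hardware that supports them.
quant_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

model = te.Linear(1024, 1024, device="cuda")
x = torch.randn(32, 1024, device="cuda")

# Old: with te.fp8_autocast(enabled=True, fp8_recipe=quant_recipe, fp8_group=group):
# New: recipe-agnostic context manager.
with te.autocast(enabled=True, recipe=quant_recipe):
    y = model(x)

# Gradients and the optimizer step stay outside the autocast context.
y.sum().backward()
```
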
......@@ -264,11 +264,9 @@ def train_and_evaluate(args):
fp8_recipe = None
device_mesh = mesh_utils.create_device_mesh((num_gpu,))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)
) as mesh, te.fp8_autocast(
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh, te.autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
recipe=fp8_recipe,
mesh_resource=te.MeshResource(dp_resource=DEVICE_DP_AXIS),
):
......@@ -282,7 +280,7 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
# Add TE logical axis rules to our Flax logical axis rule context. This must be done inside fp8_autocast
# Add TE logical axis rules to our Flax logical axis rule context. This must be done inside autocast
sharding_rules = te_flax.extend_logical_axis_rules(tuple())
with flax.linen.logical_axis_rules(sharding_rules):
encoder = Net(num_embed)
......
......@@ -393,9 +393,9 @@ def train_and_evaluate(args):
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh, te.fp8_autocast(
) as mesh, te.autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
recipe=fp8_recipe,
mesh_resource=te.MeshResource(
dp_resource=DEVICE_DP_AXIS,
tpsp_resource=DEVICE_TP_AXIS,
......@@ -413,7 +413,7 @@ def train_and_evaluate(args):
# Create custom Flax logical axis rules for sharding.
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
# Extend the logical axis rules with TE's rules. This must be done inside fp8_autocast.
# Extend the logical axis rules with TE's rules. This must be done inside autocast.
sharding_rules = te_flax.extend_logical_axis_rules(customized_rules)
with flax.linen.logical_axis_rules(sharding_rules):
......
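
For the JAX examples above, the pattern is the same: the mesh context and `te.autocast` are entered together, and the TE logical axis rules are extended inside the autocast context. A condensed, hedged sketch of the data-parallel variant, where the axis name and the `DenseGeneral` layer stand in for the example's `DEVICE_DP_AXIS` and `Net`:

```python
import jax
import jax.numpy as jnp
import flax
from jax.experimental import mesh_utils
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe

DEVICE_DP_AXIS = "data"  # placeholder axis name
fp8_recipe = recipe.DelayedScaling()
device_mesh = mesh_utils.create_device_mesh((jax.device_count(),))

with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)), te.autocast(
    enabled=True,
    recipe=fp8_recipe,
    mesh_resource=te.MeshResource(dp_resource=DEVICE_DP_AXIS),
):
    # Extending the logical axis rules must happen inside the autocast context.
    sharding_rules = te_flax.extend_logical_axis_rules(tuple())
    with flax.linen.logical_axis_rules(sharding_rules):
        layer = te_flax.DenseGeneral(features=256)  # stands in for the example's Net
        var_collect = layer.init(jax.random.PRNGKey(0), jnp.ones((8, 256), jnp.bfloat16))
```
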
......@@ -227,8 +227,8 @@ def train_and_evaluate(args):
else:
fp8_recipe = None
with te.fp8_autocast(
enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
with te.autocast(
enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
):
encoder = Net(num_embed)
# We use nn.Embed, thus inputs need to be in int
......
......@@ -6,13 +6,13 @@ This example uses MNIST training to demonstrate the Transformer Engine usage. Th
2. Define model: The `Net` class is a small CNN model for image classification. It has an option to switch between using `nn.Dense` provided by Flax and `te.DenseGeneral` provided by the Transformer Engine. This allows for easy comparison between the two libraries.
3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. For FP8 training, the key is `te.fp8_autocast` context manager. If fp8_autocast is enabled, it will cast all `te.DenseGeneral` to FP8 precision. The `var_collect` is a collection including needed information for model training, such as parameters and FP8 metadata, which is necessary for correct casting of BF16 tensors into FP8 tensors at runtime. If fp8_autocast is turned on and print var_collect, you will see FP8 metadata inside, such as `fp8_meta_collection` section. The training and evaluating with FP8 have to be done under fp8_autocast. If not, then fp8_autocast will deconstruct the FP8 metadata, and the model will fall back to higher floating point precision, such as BF16 in this example. To check if FP8 is enabled, use the `check_fp8` routine. If model initialization with FP8 works fine, the string returned by jax.make_jaxpr should include the `Float8` keyword.
3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. For FP8 training, the key is `te.autocast` context manager. If autocast is enabled, it will cast all `te.DenseGeneral` to FP8 precision. The `var_collect` is a collection including needed information for model training, such as parameters and FP8 metadata, which is necessary for correct casting of BF16 tensors into FP8 tensors at runtime. If autocast is turned on and print var_collect, you will see FP8 metadata inside, such as `fp8_meta_collection` section. The training and evaluating with FP8 have to be done under autocast. If not, then autocast will deconstruct the FP8 metadata, and the model will fall back to higher floating point precision, such as BF16 in this example. To check if FP8 is enabled, use the `check_fp8` routine. If model initialization with FP8 works fine, the string returned by jax.make_jaxpr should include the `Float8` keyword.
4. Training process: In `apply_model`, the main difference between normal Flax usage and this example is, with FP8 training, the FP8 metadata has to be filled into the gradient function `grad_fn`. Otherwise, the Transformer Engine doesn't know how to cast the BF16 tensor into FP8 tensor at runtime correctly. The FP8 metadata doesn't belong in model parameters (`state.params`), so we need to manually combine the metadata and latest model parameters into var_collect as a frozen dictionary and fill it to the gradient function.
5. Evaluating process: The evaluating process is the same as the training process. Need to ensure FP8 metadata is inside var_collect and fill it into loss function.
6. Additional options: The `te.fp8_autocast` context manager has additional options
6. Additional options: The `te.autocast` context manager has additional options
* FP8 Recipe: control FP8 training behavior. See the [FP8 tutorial](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) for a detailed explanation of FP8 recipes and the supported options.
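
A hedged sketch of step 3 above, using a single `te_flax.DenseGeneral` in place of the example's `Net`; the shapes and recipe are illustrative, and the exact collection names in `var_collect` may differ by version:

```python
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe

with te.autocast(enabled=True, recipe=recipe.DelayedScaling(), mesh_resource=te.MeshResource()):
    layer = te_flax.DenseGeneral(features=64)  # stands in for Net
    x = jnp.ones((8, 64), jnp.bfloat16)
    # var_collect holds the parameters plus the FP8 metadata (e.g. an
    # "fp8_meta_collection" section) needed to cast BF16 tensors at runtime.
    var_collect = layer.init(jax.random.PRNGKey(0), x)
    print(list(var_collect.keys()))
    # Sanity check from step 3: the traced program should mention a Float8
    # dtype when FP8 is active.
    print(jax.make_jaxpr(layer.apply)(var_collect, x))
```
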
## Run ##
......
......@@ -193,8 +193,8 @@ def train_and_evaluate(args):
else:
fp8_recipe = None
with te.fp8_autocast(
enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
with te.autocast(
enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
):
cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
......
......@@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None):
)
parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
"--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
)
parser.add_argument(
"--no-comm-overlap",
......@@ -299,7 +299,7 @@ def _train(opts):
dist_print(" |-- Forward pass", group=tp_group, debug=True)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world):
y = model(x)
if isinstance(y, tuple):
out, *_ = y
......
......@@ -49,5 +49,5 @@ $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsd
# ...
```
**NOTE:** This example has `fp8_autocast()` enabled by default. To run on GPUs without Fp8 support
**NOTE:** This example has `autocast()` enabled by default. To run on GPUs without Fp8 support
(e.g.: A100), add the `--no-fp8` option to the commands shown above.
......@@ -173,7 +173,7 @@ def parse_fsdp_args():
"--no-fp8",
action="store_true",
default=False,
help="Disables the te.fp8_autocast() context.",
help="Disables the te.autocast() context.",
)
parser.add_argument(
"--no-defer-init",
......@@ -284,11 +284,11 @@ def train(opts):
dtype=opts.dtype,
device="cuda",
)
# fp8_autocast needs to be given the FSDP process group for amax reductions
with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
# autocast needs to be given the FSDP process group for amax reductions
with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus):
y = te_model(x)
loss = y.sum()
# calculate gradient and take training step outside the fp8_autocast context
# calculate gradient and take training step outside the autocast context
loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)
......
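
A hedged sketch of the FSDP pattern from the hunk above: the autocast context receives the process group used for amax reductions, while the backward pass and optimizer step run outside it. The `te.Linear` layer, sizes, and recipe are placeholders; the real example builds a larger `te_model`:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
all_gpus = dist.new_group(backend="nccl")

te_model = FSDP(te.Linear(1024, 1024, device="cuda"))
optim = torch.optim.Adam(te_model.parameters(), lr=1e-4)
fp8_recipe = recipe.DelayedScaling()

x = torch.randn(32, 1024, device="cuda")
# autocast needs the FSDP process group so amax statistics are reduced
# across every rank that holds a shard of the weights.
with te.autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=all_gpus):
    y = te_model(x)
loss = y.sum()
# Gradient and training step happen outside the autocast context.
loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)
```
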
......@@ -52,7 +52,7 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
with te.fp8_autocast(enabled=use_fp8):
with te.autocast(enabled=use_fp8):
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
......@@ -76,7 +76,7 @@ def calibrate(model, device, test_loader, fp8):
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=fp8, calibrating=True):
with te.autocast(enabled=fp8, calibrating=True):
output = model(data)
......@@ -88,7 +88,7 @@ def test(model, device, test_loader, use_fp8):
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8):
with te.autocast(enabled=use_fp8):
output = model(data)
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
......
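
The `calibrating=True` path in the hunk above can also be used on its own: a few forward passes with quantized compute disabled but calibration enabled record the amax statistics that a later FP8 run starts from. A hedged sketch, with a placeholder layer and data:

```python
import torch
import transformer_engine.pytorch as te

model = te.Linear(784, 10, device="cuda")
data = torch.randn(64, 784, device="cuda")

with torch.no_grad():
    # Compute in higher precision, but collect quantization statistics.
    with te.autocast(enabled=False, calibrating=True):
        _ = model(data)

# Subsequent runs can enable FP8 and reuse the calibrated statistics.
with te.autocast(enabled=True):
    out = model(data)
```
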
......@@ -15,7 +15,7 @@ from distributed_test_base import generate_configs
from utils import assert_allclose, pytest_parametrize_wrapper
import transformer_engine.jax.cpp_extensions as tex
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
from transformer_engine.jax.dense import dense
......@@ -127,7 +127,7 @@ class TestDistributedDense:
contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension
with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource):
with mesh, autocast(enabled=False, mesh_resource=mesh_resource):
# TE GEMM result
te_result = _jitted_gemm(
x_sharded,
......@@ -209,7 +209,7 @@ class TestDistributedDense:
contracting_dims = ((2,), (0,))
with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource):
with mesh, autocast(enabled=False, mesh_resource=mesh_resource):
# Test gradients w.r.t. all inputs
te_grad_func = jax.jit(
jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)),
......
......@@ -9,7 +9,7 @@ import numpy as np
from utils import pytest_parametrize_wrapper, is_devices_enough
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
def generate_mesh_configs():
......@@ -26,10 +26,10 @@ def generate_mesh_configs():
class TestMeshResource(unittest.TestCase):
def test_fp8_autocast_with_mesh_resource(self):
def test_autocast_with_mesh_resource(self):
for mesh_config in generate_mesh_configs():
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource):
with mesh, autocast(enabled=False, mesh_resource=mesh_resource):
self.assertEqual(mesh_resource, global_mesh_resource())
......@@ -15,7 +15,7 @@ from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import pytest_parametrize_wrapper
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
from transformer_engine.common import recipe
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available
......@@ -133,7 +133,7 @@ class TestDistributedLayernorm:
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
......@@ -209,7 +209,7 @@ class TestDistributedLayernorm:
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
......
......@@ -23,7 +23,7 @@ from transformer_engine.jax.quantize import (
ScalingMode,
get_quantize_config_with_recipe,
)
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.sharding import (
......@@ -210,9 +210,9 @@ class TestDistributedLayernormMLP:
)
# Single GPU
with fp8_autocast(
with autocast(
enabled=quantization_recipe is not None,
fp8_recipe=quantization_recipe,
recipe=quantization_recipe,
mesh_resource=MeshResource(),
):
single_jitter = jax.jit(
......@@ -224,9 +224,9 @@ class TestDistributedLayernormMLP:
# Multi GPUs
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(
with mesh, autocast(
enabled=quantization_recipe is not None,
fp8_recipe=quantization_recipe,
recipe=quantization_recipe,
mesh_resource=mesh_resource,
):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp"))
......@@ -381,8 +381,8 @@ class TestDistributedLayernormMLP:
with use_jax_gemm(enabled=with_jax_gemm):
# Single GPUs
with fp8_autocast(
enabled=use_fp8, fp8_recipe=quantization_recipe, mesh_resource=MeshResource()
with autocast(
enabled=use_fp8, recipe=quantization_recipe, mesh_resource=MeshResource()
):
ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type,
......@@ -399,8 +399,8 @@ class TestDistributedLayernormMLP:
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(
enabled=use_fp8, fp8_recipe=quantization_recipe, mesh_resource=mesh_resource
with mesh, autocast(
enabled=use_fp8, recipe=quantization_recipe, mesh_resource=mesh_resource
):
ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type,
......
......@@ -15,7 +15,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
from transformer_engine.jax.softmax import SoftmaxType, softmax
DTYPES = [jnp.float16, jnp.bfloat16]
......@@ -102,7 +102,7 @@ class TestDistributedSoftmax:
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
with mesh, autocast(mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
......
......@@ -22,7 +22,7 @@ from jax import value_and_grad, jit
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import (
AttnBiasType,
......@@ -771,7 +771,7 @@ class FusedAttnRunner:
],
)
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
with self.mesh, autocast(mesh_resource=self.mesh_resource):
primitive_out = customcall_fused_dpa_jit(*customcall_args)
primitive_out = self.cp_inverse_reorder_fn(primitive_out)
......@@ -788,7 +788,7 @@ class FusedAttnRunner:
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
if self.coll_count_ref is not None:
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
with self.mesh, autocast(mesh_resource=self.mesh_resource):
target_hlo = (
customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
)
......@@ -888,7 +888,7 @@ class FusedAttnRunner:
)
)
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
with self.mesh, autocast(mesh_resource=self.mesh_resource):
primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
reference_out, reference_dgrad = jitted_reference(*args)
......@@ -959,7 +959,7 @@ class FusedAttnRunner:
)
if self.coll_count_ref is not None:
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
with self.mesh, autocast(mesh_resource=self.mesh_resource):
target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
assert_equal_collectives(target_hlo, self.coll_count_ref)
......
......@@ -17,7 +17,7 @@ from transformer_engine.common.recipe import (
NVFP4BlockScaling,
)
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
from transformer_engine.jax.quantize import (
get_quantize_config,
is_scaling_mode_supported,
......@@ -97,84 +97,78 @@ class TestFP8Functions(unittest.TestCase):
)
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_delayed_scaling(self):
def test_autocast_delayed_scaling(self):
self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()):
with autocast(enabled=False, recipe=DelayedScaling(), mesh_resource=MeshResource()):
self._check_default_state()
self._check_default_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(ds)
self._check_default_state()
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(ds)
self._check_default_state()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_current_scaling(self):
def test_autocast_current_scaling(self):
self._check_default_state()
with fp8_autocast(
enabled=False, fp8_recipe=Float8CurrentScaling(), mesh_resource=MeshResource()
):
with autocast(enabled=False, recipe=Float8CurrentScaling(), mesh_resource=MeshResource()):
self._check_default_state()
self._check_default_state()
cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_default_state()
cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_block_scaling(self):
def test_autocast_mxfp8_block_scaling(self):
self._check_default_state()
with fp8_autocast(
enabled=False, fp8_recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()
):
with autocast(enabled=False, recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()):
self._check_default_state()
self._check_default_state()
bs = MXFP8BlockScaling()
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
@unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
def test_fp8_autocast_nvfp4_block_scaling(self):
def test_autocast_nvfp4_block_scaling(self):
self._check_default_state()
with fp8_autocast(
enabled=False, fp8_recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()
):
with autocast(enabled=False, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()):
self._check_default_state()
self._check_default_state()
bs = NVFP4BlockScaling()
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
......
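
The tests above exercise the point of the rename: the same `autocast` entry point accepts any supported recipe class. A hedged sketch of that shape, assuming the chosen recipes are all supported on the current GPU (the tests gate each one with an availability check):

```python
from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    MXFP8BlockScaling,
    NVFP4BlockScaling,
)
from transformer_engine.jax import autocast
from transformer_engine.jax.quantize import get_quantize_config
from transformer_engine.jax.sharding import MeshResource

for quant_recipe in (
    DelayedScaling(),
    Float8CurrentScaling(),
    MXFP8BlockScaling(),
    NVFP4BlockScaling(),
):
    with autocast(enabled=True, recipe=quant_recipe, mesh_resource=MeshResource()):
        assert get_quantize_config().is_fp8_enabled()
```
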
......@@ -28,7 +28,7 @@ from transformer_engine.jax.quantize import (
is_fp8_available,
update_collections,
TensorSource,
fp8_autocast,
autocast,
)
from transformer_engine.jax.sharding import MeshResource
......@@ -507,14 +507,14 @@ class BaseTester:
"""Test normal datatype forward"""
# Ensure FP8 disabled.
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
with autocast(enabled=False, mesh_resource=MeshResource()):
self.runner(attrs).test_forward(data_shape, dtype)
def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward"""
# Ensure FP8 disabled.
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
with autocast(enabled=False, mesh_resource=MeshResource()):
self.runner(attrs).test_backward(data_shape, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
......@@ -522,7 +522,7 @@ class BaseTester:
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test forward with fp8 enabled"""
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()):
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
......@@ -530,7 +530,7 @@ class BaseTester:
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test backward with fp8 enabled"""
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()):
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
......
......@@ -8,16 +8,15 @@ import logging
from contextlib import nullcontext
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
get_cu_seqlens_on_cp_rank,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize
import transformer_engine_torch as tex
from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
from transformer_engine.pytorch import (
autocast,
DotProductAttention,
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
......@@ -306,7 +305,7 @@ def run_dpa_with_cp(
############ run without CP ############
logging.info(f"[Rank {rank}] Run without context parallelism")
if dtype == "fp8":
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group)
else:
fp8_context = nullcontext()
with fp8_context:
......@@ -396,7 +395,7 @@ def run_dpa_with_cp(
if dtype == "fp8":
core_attn.fp8_initialized = False
core_attn.fp8_meta_tensors_initialized = False
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group)
else:
fp8_context = nullcontext()
......
......@@ -10,13 +10,22 @@ from typing import Any, Dict, Tuple, Union
import pytest
import torch
from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype
from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
from transformer_engine.pytorch.attention.dot_product_attention import (
from transformer_engine.pytorch import (
TransformerLayer,
autocast,
quantized_model_init,
DotProductAttention,
MultiheadAttention,
get_device_compute_capability,
Quantizer,
is_fp8_available,
is_bf16_available,
)
from transformer_engine.pytorch.attention.dot_product_attention import (
_attention_backends,
)
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils,
check_set_window_size,
......@@ -29,18 +38,14 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.quantized_tensor import (
Quantizer,
prepare_for_saving,
restore_from_saved,
)
......@@ -56,7 +61,7 @@ from utils import (
)
# Check if hardware supports FP8
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
# Reset RNG seed and states
seed = 1234
......@@ -67,12 +72,12 @@ reset_rng_states()
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
fp8.FP8GlobalStateManager.reset()
FP8GlobalStateManager.reset()
# Define F16 data types to test
param_types = [torch.float16]
if is_bf16_compatible():
if is_bf16_available():
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]
......@@ -1592,7 +1597,7 @@ def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
with quantized_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
......@@ -1609,7 +1614,7 @@ def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
block = get_model(dtype, config)
for i in range(steps // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
......@@ -1644,7 +1649,7 @@ def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
assert not param_grads, "Oops!"
for i in range((steps + 1) // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
......@@ -1820,7 +1825,7 @@ def _run_mha_fp8_vs_f16(
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe):
with quantized_model_init(enabled=fp8_mha, recipe=fp8_recipe):
rotary_pos_emb = None
if RoPE:
PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
......@@ -1892,7 +1897,7 @@ def _run_mha_fp8_vs_f16(
tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
out_grad = tensor.view(*tensor.shape[:-2], -1)
with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe):
with autocast(enabled=fp8_mha, recipe=fp8_recipe):
out = mha(
hidden_states,
attn_mask_type=config.attn_mask_type,
......@@ -2110,7 +2115,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_rec
return _DUMMY_CUDA_RNG_STATE_TRACKER
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
with fp8_model_init(enabled=fp8_dpa):
with quantized_model_init(enabled=fp8_dpa):
dpa = DotProductAttention(
config.num_heads,
config.head_dim_qk,
......@@ -2202,7 +2207,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_rec
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe):
with autocast(enabled=fp8_dpa, recipe=fp8_recipe):
out = dpa(
inp[0],
inp[1],
......@@ -2343,7 +2348,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
)
mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with autocast(enabled=True, recipe=fp8_recipe):
out = mha(inp, cu_seqlens, config.max_seqlen_q)
out.backward(out_grad)
......@@ -2541,7 +2546,7 @@ class _custom_mha_fp8(torch.autograd.Function):
)
proj_dgrad = ctx.dO_quantizer(grad_output)
fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_s,
......
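
The attention test updates above also show the widened public surface of `transformer_engine.pytorch`: `autocast`, `quantized_model_init` (formerly `fp8_model_init`), `is_fp8_available`, and `is_bf16_available` are now imported from the top-level package. A hedged sketch of that usage, with an illustrative `Linear` layer and recipe:

```python
import torch
from transformer_engine.common import recipe
from transformer_engine.pytorch import (
    Linear,
    autocast,
    quantized_model_init,
    is_fp8_available,
    is_bf16_available,
)

fp8_ok, why_not = is_fp8_available(return_reason=True)
dtype = torch.bfloat16 if is_bf16_available() else torch.float16
fp8_recipe = recipe.DelayedScaling()

# quantized_model_init (the renamed fp8_model_init) creates parameters in the
# quantized format directly when enabled.
with quantized_model_init(enabled=fp8_ok, recipe=fp8_recipe):
    layer = Linear(1024, 1024, params_dtype=dtype, device="cuda")

x = torch.randn(8, 1024, dtype=dtype, device="cuda")
with autocast(enabled=fp8_ok, recipe=fp8_recipe):
    y = layer(x)
```
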
......@@ -10,7 +10,7 @@ import logging
import pytest
import torch
from transformer_engine.pytorch.utils import (
from transformer_engine.pytorch import (
get_device_compute_capability,
get_cudnn_version,
)
......