Commit a207db1d authored by yuguo's avatar yuguo
Browse files
parents fbee8990 69365f88
...@@ -18,13 +18,14 @@ from flax.linen.attention import combine_masks ...@@ -18,13 +18,14 @@ from flax.linen.attention import combine_masks
from jax import lax, vmap from jax import lax, vmap
from jax import nn as jax_nn from jax import nn as jax_nn
from jax import random as jax_random from jax import random as jax_random
import pytest
from transformer_engine.jax.attention import ( from transformer_engine.jax.attention import (
AttnMaskType, AttnMaskType,
canonicalize_attn_mask_type, canonicalize_attn_mask_type,
make_swa_mask, make_swa_mask,
) )
from transformer_engine.jax.fp8 import DType as TEDType from transformer_engine.jax.quantize.helper import DType as TEDType
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
...@@ -96,6 +97,62 @@ def combine_biases(*masks: Optional[Array]): ...@@ -96,6 +97,62 @@ def combine_biases(*masks: Optional[Array]):
return mask return mask
def parameterize_by_test_level(param_dict: dict, id_prefix: str = ""):
    """
    Pick the parameter set that belongs to the active unit-test level.

    `param_dict` is keyed by test-level name ("L0", etc.). The active level is read from
    the NVTE_JAX_UNITTEST_LEVEL environment variable, falling back to "L0" when the
    variable is not set. The values stored under that level are handed to
    `values_to_named_params` together with `id_prefix`, and its result is returned.

    Raises:
        ValueError: if `param_dict` has no entry for the active level.
    """
    fallback_level = "L0"
    active_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", fallback_level)
    if active_level in param_dict:
        return values_to_named_params(param_dict[active_level], id_prefix)
    raise ValueError("Unsupported test level")
def value_to_test_name_str(value):
    """Converts a value to how it should appear in a test name.

    Tuples and lists are flattened recursively and their elements are joined with "_".
    Objects of the same type as ``jnp.float32`` are rendered through their ``dtype``
    attribute (e.g. "float32"). Every other value is rendered with ``str()``.

    Returns:
        Always a ``str``, so the result can be joined with other name fragments or
        embedded in a pytest id.
    """
    if isinstance(value, (tuple, list)):
        return "_".join(value_to_test_name_str(item) for item in value)
    # Derive the scalar-type class from a known member rather than importing it.
    scalar_type_cls = type(jnp.float32)
    if isinstance(value, scalar_type_cls):
        # `value.dtype` is a dtype object, not a string. str.join() in the branch above
        # rejects non-str items, so convert explicitly.
        return str(value.dtype)
    return str(value)
def value_to_named_param(value, id_prefix: str = ""):
    """Wrap `value` in a pytest parameter carrying a readable id.

    A value that already has the type produced by `pytest.param` is returned untouched.
    Otherwise the value is wrapped and its id is set to
    "<id_prefix>_<value_to_test_name_str(value)>".
    """
    already_wrapped = isinstance(value, type(pytest.param(0)))
    if already_wrapped:
        return value
    test_id = f"{id_prefix}_{value_to_test_name_str(value)}"
    return pytest.param(value, id=test_id)
def values_to_named_params(params, id_prefix: str = ""):
    """Convert each entry of `params` with `value_to_named_param`, keeping the order."""
    named = []
    for entry in params:
        named.append(value_to_named_param(entry, id_prefix=id_prefix))
    return named
def pytest_parametrize_wrapper(param_name, param_values):
    """
    Drop-in replacement for pytest.mark.parametrize that names test cases
    automatically from their parameter values.

    * If `param_values` is a dict, it is resolved through `parameterize_by_test_level`.
    * Otherwise, if `param_name` names a single parameter (contains no comma), the values
      are converted with `values_to_named_params`.
    * Comma separated parameters in one parametrize call aren't supported for automatic
      naming yet; such values are passed through and keep the default pytest names.

    In the first two cases `param_name` is used as the id prefix.

    Returns a decorator that applies pytest.mark.parametrize to the decorated function.
    """
    if isinstance(param_values, dict):
        param_values = parameterize_by_test_level(param_values, id_prefix=param_name)
    elif "," not in param_name:
        param_values = values_to_named_params(param_values, id_prefix=param_name)

    def apply_parametrize(func):
        mark = pytest.mark.parametrize(param_name, param_values)
        return mark(func)

    return apply_parametrize
class DotProductAttention(nn.Module): class DotProductAttention(nn.Module):
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
scale_attn_logits: bool = True scale_attn_logits: bool = True
...@@ -140,6 +197,7 @@ class DotProductAttention(nn.Module): ...@@ -140,6 +197,7 @@ class DotProductAttention(nn.Module):
Returns: Returns:
Output of shape `[batch, length, num_heads, v_depth_per_head]`. Output of shape `[batch, length, num_heads, v_depth_per_head]`.
""" """
input_dtype = query.dtype
assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank." assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
batch_dim = 1 if self.transpose_batch_sequence else 0 batch_dim = 1 if self.transpose_batch_sequence else 0
assert ( assert (
...@@ -152,7 +210,7 @@ class DotProductAttention(nn.Module): ...@@ -152,7 +210,7 @@ class DotProductAttention(nn.Module):
if self.scale_attn_logits: if self.scale_attn_logits:
head_dim = query.shape[-1] head_dim = query.shape[-1]
depth_scaling = jnp.sqrt(head_dim).astype(self.dtype) depth_scaling = jnp.sqrt(head_dim).astype(input_dtype)
query = query / depth_scaling query = query / depth_scaling
# Casting logits and softmax computation for float32 for model stability. # Casting logits and softmax computation for float32 for model stability.
...@@ -181,7 +239,7 @@ class DotProductAttention(nn.Module): ...@@ -181,7 +239,7 @@ class DotProductAttention(nn.Module):
attn_weights = attn_weights + bias.astype(attn_weights.dtype) attn_weights = attn_weights + bias.astype(attn_weights.dtype)
# Normalize the attention weights across `kv_length` dimension. # Normalize the attention weights across `kv_length` dimension.
attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype) attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype)
# Apply attention dropout. # Apply attention dropout.
if not deterministic and self.dropout_rate > 0.0: if not deterministic and self.dropout_rate > 0.0:
...@@ -191,16 +249,20 @@ class DotProductAttention(nn.Module): ...@@ -191,16 +249,20 @@ class DotProductAttention(nn.Module):
dropout_shape = list(attn_weights.shape) dropout_shape = list(attn_weights.shape)
dropout_rng = self.make_rng("dropout") dropout_rng = self.make_rng("dropout")
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype) multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
attn_weights = attn_weights * multiplier attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
attn_weights = attn_weights.astype(value.dtype) # attn_weights = attn_weights.astype(input_dtype)
# Take the linear combination of `value`. # Take the linear combination of `value`.
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape) return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
assert (
attn_weights.dtype == input_dtype
), f"input.dtype={input_dtype}, output.dtype={attn_weights.dtype}"
return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape) return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
...@@ -246,7 +308,6 @@ class DenseGeneral(nn.Module): ...@@ -246,7 +308,6 @@ class DenseGeneral(nn.Module):
features = _canonicalize_tuple(self.features) features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis) axis = _canonicalize_tuple(self.axis)
inputs = jnp.asarray(inputs, self.dtype)
axis = _normalize_axes(axis, inputs.ndim) axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
...@@ -268,11 +329,14 @@ class DenseGeneral(nn.Module): ...@@ -268,11 +329,14 @@ class DenseGeneral(nn.Module):
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) y = lax.dot_general(
y = y.astype(input_dtype) inputs, kernel, ((axis, contract_ind), ((), ())), preferred_element_type=input_dtype
)
if bias is not None: if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
assert y.dtype == inputs.dtype, f"input.dtype={inputs.dtype}, output.dtype={y.dtype}"
return y return y
...@@ -352,6 +416,7 @@ class MlpBlock(nn.Module): ...@@ -352,6 +416,7 @@ class MlpBlock(nn.Module):
)( )(
x, deterministic=deterministic x, deterministic=deterministic
) # Broadcast along length. ) # Broadcast along length.
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp")) x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
else: else:
...@@ -365,6 +430,7 @@ class MlpBlock(nn.Module): ...@@ -365,6 +430,7 @@ class MlpBlock(nn.Module):
bias_axes="embed", bias_axes="embed",
name="wo", name="wo",
)(x) )(x)
assert ( assert (
output.dtype == inputs.dtype output.dtype == inputs.dtype
), f"input.dtype={input.dtype}, output.dtype={output.dtype}" ), f"input.dtype={input.dtype}, output.dtype={output.dtype}"
...@@ -391,7 +457,7 @@ def apply_rotary_pos_emb_alternate( ...@@ -391,7 +457,7 @@ def apply_rotary_pos_emb_alternate(
second_part = second_half * cos + first_half * sin second_part = second_half * cos + first_half * sin
first_part = first_part.astype(inputs.dtype) first_part = first_part.astype(inputs.dtype)
second_part = second_part.astype(inputs.dtype) second_part = second_part.astype(inputs.dtype)
return jnp.concatenate([first_part, second_part], axis=-1) return jnp.concatenate([first_part, second_part], axis=-1).astype(inputs.dtype)
def apply_rotary_pos_emb_consecutive( def apply_rotary_pos_emb_consecutive(
...@@ -425,7 +491,7 @@ def apply_rotary_pos_emb_consecutive( ...@@ -425,7 +491,7 @@ def apply_rotary_pos_emb_consecutive(
sign = jnp.sign(jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2) - 0.5) sign = jnp.sign(jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2) - 0.5)
outputs = inputs * cos + inputs_shifted * sin * sign outputs = inputs * cos + inputs_shifted * sin * sign
return outputs return outputs.astype(inputs.dtype)
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
...@@ -559,6 +625,7 @@ class MultiHeadAttention(nn.Module): ...@@ -559,6 +625,7 @@ class MultiHeadAttention(nn.Module):
if self.fuse_qkv: if self.fuse_qkv:
if is_qkvpack: if is_qkvpack:
qkv_proj = DenseGeneral( qkv_proj = DenseGeneral(
axis=-1, axis=-1,
features=self.num_heads * self.head_dim * 3, features=self.num_heads * self.head_dim * 3,
...@@ -569,11 +636,13 @@ class MultiHeadAttention(nn.Module): ...@@ -569,11 +636,13 @@ class MultiHeadAttention(nn.Module):
name="qkv", name="qkv",
dtype=self.dtype, dtype=self.dtype,
)(inputs_kv) )(inputs_kv)
query, key, value = jnp.split( query, key, value = jnp.split(
qkv_proj, qkv_proj,
[self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2], [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
axis=-1, axis=-1,
) )
else: else:
query = q_projection(kernel_init=query_init, name="query")(inputs_q) query = q_projection(kernel_init=query_init, name="query")(inputs_q)
...@@ -711,6 +780,7 @@ class MultiHeadAttention(nn.Module): ...@@ -711,6 +780,7 @@ class MultiHeadAttention(nn.Module):
# Convert the boolean attention mask to an attention bias. # Convert the boolean attention mask to an attention bias.
if mask is not None: if mask is not None:
# attention mask in the form of attention bias # attention mask in the form of attention bias
attention_bias = lax.select( attention_bias = lax.select(
mask > 0, mask > 0,
jnp.full(mask.shape, 0.0).astype(self.dtype), jnp.full(mask.shape, 0.0).astype(self.dtype),
...@@ -740,6 +810,7 @@ class MultiHeadAttention(nn.Module): ...@@ -740,6 +810,7 @@ class MultiHeadAttention(nn.Module):
x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv")) x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
# Back to the original inputs dimensions. # Back to the original inputs dimensions.
out = DenseGeneral( out = DenseGeneral(
features=inputs_q.shape[-1], # output dim is set to the input dim. features=inputs_q.shape[-1], # output dim is set to the input dim.
axis=-1, axis=-1,
...@@ -750,6 +821,7 @@ class MultiHeadAttention(nn.Module): ...@@ -750,6 +821,7 @@ class MultiHeadAttention(nn.Module):
dtype=self.dtype, dtype=self.dtype,
name="out", name="out",
)(x) )(x)
assert ( assert (
inputs_q.dtype == inputs_kv.dtype == out.dtype inputs_q.dtype == inputs_kv.dtype == out.dtype
), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}" ), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}"
...@@ -784,12 +856,11 @@ class LayerNorm(nn.Module): ...@@ -784,12 +856,11 @@ class LayerNorm(nn.Module):
scale = nn_partitioning.param_with_axes( scale = nn_partitioning.param_with_axes(
"scale", self.scale_init, (features,), self.dtype, axes=("embed",) "scale", self.scale_init, (features,), self.dtype, axes=("embed",)
) )
scale = jnp.asarray(scale, input_dtype) x_ = x.astype(jnp.float32)
if self.layernorm_type == "layernorm": if self.layernorm_type == "layernorm":
mean = jnp.mean(x, axis=-1, keepdims=True) mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
y = (x - mean) * lax.rsqrt(var + self.epsilon) y = (x_ - mean) * lax.rsqrt(var + self.epsilon)
bias = nn_partitioning.param_with_axes( bias = nn_partitioning.param_with_axes(
"ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",) "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
...@@ -803,9 +874,10 @@ class LayerNorm(nn.Module): ...@@ -803,9 +874,10 @@ class LayerNorm(nn.Module):
else: else:
assert self.layernorm_type == "rmsnorm" assert self.layernorm_type == "rmsnorm"
assert not self.zero_centered_gamma assert not self.zero_centered_gamma
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) mean2 = jnp.mean(lax.square(x_), axis=-1, keepdims=True)
y = x * lax.rsqrt(mean2 + self.epsilon) y = x_ * lax.rsqrt(mean2 + self.epsilon)
z = y * scale z = y * scale
z = z.astype(input_dtype)
assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}" assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}"
return z return z
...@@ -1085,9 +1157,11 @@ class EncoderLayer(nn.Module): ...@@ -1085,9 +1157,11 @@ class EncoderLayer(nn.Module):
fuse_wi=self.fuse_mlp_wi, fuse_wi=self.fuse_mlp_wi,
name="mlp", name="mlp",
)(y, deterministic=deterministic) )(y, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)( y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
y, deterministic=deterministic y, deterministic=deterministic
) )
if self.drop_path > 0.0: if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim) drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)( y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
...@@ -1103,6 +1177,7 @@ class EncoderLayer(nn.Module): ...@@ -1103,6 +1177,7 @@ class EncoderLayer(nn.Module):
dtype=self.dtype, dtype=self.dtype,
name="output_layernorm", name="output_layernorm",
)(y) )(y)
assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}" assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}"
return y return y
......
...@@ -318,7 +318,6 @@ def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size): ...@@ -318,7 +318,6 @@ def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
device=device, device=device,
with_amax_reduction=True, with_amax_reduction=True,
amax_reduction_group=tp_group, amax_reduction_group=tp_group,
amax_reduction_size=tp_size,
) )
quantizer = quantizer_class( quantizer = quantizer_class(
fp8_dtype=fp8_dtype, fp8_dtype=fp8_dtype,
......
...@@ -741,7 +741,6 @@ def _test_fp8_scale_update( ...@@ -741,7 +741,6 @@ def _test_fp8_scale_update(
fp8_format = transformer_engine.common.recipe.Format.HYBRID fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling( recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin, margin=margin,
interval=1,
fp8_format=fp8_format, fp8_format=fp8_format,
amax_history_len=amax_history_len, amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo, amax_compute_algo=amax_compute_algo,
......
...@@ -286,6 +286,12 @@ def run_dpa_with_cp( ...@@ -286,6 +286,12 @@ def run_dpa_with_cp(
else: else:
out_.backward(dout_) out_.backward(dout_)
if fp8_mha:
assert isinstance(out, Float8Tensor)
assert isinstance(out_, Float8Tensor)
out = out.dequantize()
out_ = out_.dequantize()
for x in [out_, q_.grad, k_.grad, v_.grad]: for x in [out_, q_.grad, k_.grad, v_.grad]:
assert torch.all(~torch.isnan(x)) assert torch.all(~torch.isnan(x))
assert torch.all(~torch.isinf(x)) assert torch.all(~torch.isinf(x))
......
...@@ -229,7 +229,7 @@ def get_model( ...@@ -229,7 +229,7 @@ def get_model(
attn_mask_type = "causal" attn_mask_type = "causal"
qkv_format = "bshd" qkv_format = "bshd"
if mode == "inference": if mode == "inference":
attn_mask_type = "padding_causal" if backend != "FusedAttention" else "padding" attn_mask_type = "padding_causal"
fp8_recipe = recipe.DelayedScaling( fp8_recipe = recipe.DelayedScaling(
margin=0, margin=0,
...@@ -392,9 +392,9 @@ def get_tols(module, backend, dtype): ...@@ -392,9 +392,9 @@ def get_tols(module, backend, dtype):
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"]) @pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"])
@pytest.mark.parametrize("is_cuda_graph", [False, True]) @pytest.mark.parametrize("is_cuda_graph", [False, True])
@pytest.mark.parametrize("is_fp8", [False, True]) @pytest.mark.parametrize("is_fp8", [False, True])
def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8): def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8):
reset_rng_states() reset_rng_states()
logger = logging.getLogger("test_paged_attn") logger = logging.getLogger("test_kv_cache")
fp8_recipe = recipe.DelayedScaling( fp8_recipe = recipe.DelayedScaling(
margin=0, margin=0,
fp8_format=recipe.Format.HYBRID, fp8_format=recipe.Format.HYBRID,
...@@ -407,7 +407,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda ...@@ -407,7 +407,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda
fp8_meta["recipe"] = fp8_recipe fp8_meta["recipe"] = fp8_recipe
config = model_configs_infer[model] config = model_configs_infer[model]
num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1 num_layers = 2 if module == "TransformerLayer" else 1
# flash-attn v2 requires page_size >= 256 # flash-attn v2 requires page_size >= 256
if backend == "FlashAttention" and not fa_utils.v3_is_installed: if backend == "FlashAttention" and not fa_utils.v3_is_installed:
config_max_seqlen_q = config.max_seqlen_q config_max_seqlen_q = config.max_seqlen_q
...@@ -437,7 +437,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda ...@@ -437,7 +437,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda
# initialize inference_params # initialize inference_params
inference_params = InferenceParams( inference_params = InferenceParams(
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_seqlen_kv=config.max_seqlen_kv, max_sequence_length=config.max_seqlen_kv,
num_heads_kv=config.num_gqa_groups, num_heads_kv=config.num_gqa_groups,
head_dim_k=config.head_dim_qk, head_dim_k=config.head_dim_qk,
head_dim_v=config.head_dim_v, head_dim_v=config.head_dim_v,
......
...@@ -57,6 +57,7 @@ model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} ...@@ -57,6 +57,7 @@ model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
fp8_recipes = [ fp8_recipes = [
recipe.DelayedScaling(), recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(), recipe.MXFP8BlockScaling(),
recipe.Float8CurrentScaling(),
] ]
# Supported data types # Supported data types
......
...@@ -297,7 +297,6 @@ class TestFuser: ...@@ -297,7 +297,6 @@ class TestFuser:
fp8_format = transformer_engine.common.recipe.Format.HYBRID fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling( recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin, margin=margin,
interval=1,
fp8_format=fp8_format, fp8_format=fp8_format,
amax_history_len=8, amax_history_len=8,
amax_compute_algo="max", amax_compute_algo="max",
......
...@@ -56,3 +56,10 @@ def test_torch_dynamo(model_name: str): ...@@ -56,3 +56,10 @@ def test_torch_dynamo(model_name: str):
# Forward and backward pass # Forward and backward pass
out = model(*inputs) out = model(*inputs)
out.backward(torch.zeros_like(out)) out.backward(torch.zeros_like(out))
def test_lazy_compile():
"""Smoke test to ensure lazy compilation is working."""
from transformer_engine.pytorch.jit import dgelu_fused_
dgelu_fused_(torch.randn(10, 10), torch.randn(10, 10))
...@@ -2144,7 +2144,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ...@@ -2144,7 +2144,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
inference_params = InferenceParams( inference_params = InferenceParams(
max_batch_size=B_max, max_batch_size=B_max,
max_seqlen_kv=S_max, max_sequence_length=S_max,
num_heads_kv=H, num_heads_kv=H,
head_dim_k=head_size, head_dim_k=head_size,
dtype=dtype, dtype=dtype,
......
...@@ -177,7 +177,6 @@ class TestFP8Recipe: ...@@ -177,7 +177,6 @@ class TestFP8Recipe:
fp8_format = transformer_engine.common.recipe.Format.HYBRID fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling( recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin, margin=margin,
interval=1,
fp8_format=fp8_format, fp8_format=fp8_format,
amax_history_len=amax_history_len, amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo, amax_compute_algo=amax_compute_algo,
......
...@@ -110,32 +110,17 @@ model_configs = { ...@@ -110,32 +110,17 @@ model_configs = {
} }
fp8_recipes = [ fp8_recipes = [
None, # Handles non-FP8 case None, # Test non-FP8
recipe.MXFP8BlockScaling(), recipe.MXFP8BlockScaling(), # Test default
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3), recipe.Float8CurrentScaling(), # Test default
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID), recipe.DelayedScaling(), # Test default
recipe.DelayedScaling( recipe.DelayedScaling( # Test most_recent algo
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16, amax_history_len=16,
amax_compute_algo="most_recent", amax_compute_algo="most_recent",
), ),
recipe.DelayedScaling( recipe.DelayedScaling( # Test custom amax and scale compute algo
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo="max",
),
recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.E4M3, fp8_format=recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo=custom_amax_compute, amax_compute_algo=custom_amax_compute,
),
recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
scaling_factor_compute_algo=custom_amax_to_scale, scaling_factor_compute_algo=custom_amax_to_scale,
), ),
] ]
...@@ -567,6 +552,8 @@ def test_sanity_grouped_linear( ...@@ -567,6 +552,8 @@ def test_sanity_grouped_linear(
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8(): if fp8_recipe.mxfp8():
pytest.skip("Grouped linear does not support MXFP8") pytest.skip("Grouped linear does not support MXFP8")
if fp8_recipe.float8_current_scaling():
pytest.skip("Grouped linear does not support FP8 current scaling")
if not config.is_fp8_supported(): if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
......
...@@ -19,9 +19,4 @@ try: ...@@ -19,9 +19,4 @@ try:
except (ImportError, StopIteration) as e: except (ImportError, StopIteration) as e:
pass pass
try:
import transformer_engine_jax
except ImportError:
pass
__version__ = str(metadata.version("transformer_engine")) __version__ = str(metadata.version("transformer_engine"))
...@@ -233,7 +233,8 @@ if (USE_CUDA) ...@@ -233,7 +233,8 @@ if (USE_CUDA)
# Configure dependencies # Configure dependencies
target_link_libraries(transformer_engine PUBLIC target_link_libraries(transformer_engine PUBLIC
CUDA::cublas CUDA::cublas
CUDA::cudart) CUDA::cudart
CUDNN::cudnn_all)
target_include_directories(transformer_engine PRIVATE target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
......
...@@ -771,6 +771,16 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT ...@@ -771,6 +771,16 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
} }
} }
#ifndef __HIP_PLATFORM_AMD__
namespace transformer_engine {
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }
} // namespace transformer_engine
#endif
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out, const NVTETensor *bias, NVTETensor *pre_gelu_out,
......
...@@ -140,6 +140,13 @@ constexpr int num_batchgemm_streams = 1; ...@@ -140,6 +140,13 @@ constexpr int num_batchgemm_streams = 1;
constexpr int num_streams = 4; constexpr int num_streams = 4;
#endif #endif
/*! \brief TE/JAX cudaGraph requires the cuBLAS initialization to happen outside of the capturing
* region. This function is a helper to call cublasCreate() which allocate memory for the handle.
* The function will be called in the initialize phase of the related XLA custom calls.
*/
void nvte_cublas_handle_init();
} // namespace transformer_engine } // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GEMM_H_ #endif // TRANSFORMER_ENGINE_GEMM_H_
...@@ -149,6 +149,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor ...@@ -149,6 +149,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
void nvte_enable_cudnn_norm_fwd(bool enable); void nvte_enable_cudnn_norm_fwd(bool enable);
void nvte_enable_cudnn_norm_bwd(bool enable); void nvte_enable_cudnn_norm_bwd(bool enable);
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -80,7 +80,8 @@ enum NVTEScalingMode { ...@@ -80,7 +80,8 @@ enum NVTEScalingMode {
/*! Single scale per block of 32 elements consecutive in either /*! Single scale per block of 32 elements consecutive in either
rowwise or columnwise direction */ rowwise or columnwise direction */
NVTE_MXFP8_1D_SCALING = 1, NVTE_MXFP8_1D_SCALING = 1,
NVTE_INVALID_SCALING NVTE_INVALID_SCALING = 2,
NVTE_NO_SCALING = 3
}; };
/*! \brief TE Tensor type /*! \brief TE Tensor type
...@@ -346,6 +347,13 @@ enum class DType { ...@@ -346,6 +347,13 @@ enum class DType {
kNumTypes kNumTypes
}; };
/*! \brief Check if TE datatype is FP8
*
* Return true if TE datatype is FP8
* \param[in] DType TE Datatype of interest
*/
bool is_fp8_dtype(const DType t);
/*! \struct TensorWrapper /*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class. * \brief C++ wrapper for the NVTETensor class.
*/ */
......
...@@ -11,10 +11,12 @@ ...@@ -11,10 +11,12 @@
transformer_engine::ubuf_built_with_mpi*; transformer_engine::ubuf_built_with_mpi*;
*transformer_engine::rtc*; *transformer_engine::rtc*;
transformer_engine::nvte_cudnn_handle_init*; transformer_engine::nvte_cudnn_handle_init*;
transformer_engine::nvte_cublas_handle_init*;
transformer_engine::typeToSize*; transformer_engine::typeToSize*;
transformer_engine::is_fp8_dtype*;
*transformer_engine::CommOverlapBase*; *transformer_engine::CommOverlapBase*;
*transformer_engine::CommOverlapP2PBase*; *transformer_engine::CommOverlapP2PBase*;
*transformer_engine::CommOverlapCore* *transformer_engine::CommOverlapCore*
}; };
local: *; local: *;
}; };
\ No newline at end of file
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <cudnn_frontend.h> #include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h> #include <cudnn_frontend_utils.h>
#endif #endif
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include <functional> #include <functional>
...@@ -141,7 +142,6 @@ struct BackwardKernelParams : public KernelParamsBase { ...@@ -141,7 +142,6 @@ struct BackwardKernelParams : public KernelParamsBase {
}; };
enum class NVTE_Norm_Backend { Te, Cudnn }; enum class NVTE_Norm_Backend { Te, Cudnn };
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
enum class NVTE_Norm_Stage { Forward, Backward }; enum class NVTE_Norm_Stage { Forward, Backward };
using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>; using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>;
......
...@@ -162,7 +162,6 @@ class DelayedScaling(Recipe): ...@@ -162,7 +162,6 @@ class DelayedScaling(Recipe):
""" """
margin: int = 0 margin: int = 0
interval: int = -1
fp8_format: Format = Format.HYBRID fp8_format: Format = Format.HYBRID
amax_history_len: int = 1024 amax_history_len: int = 1024
amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max" amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max"
...@@ -173,12 +172,6 @@ class DelayedScaling(Recipe): ...@@ -173,12 +172,6 @@ class DelayedScaling(Recipe):
def __post_init__(self) -> None: def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
if self.interval >= 0:
warnings.warn(
"`interval` argument is deprecated and unused. "
"It will be removed in an upcoming release.",
DeprecationWarning,
)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment