Commit a207db1d authored by yuguo
parents fbee8990 69365f88
......@@ -18,13 +18,14 @@ from flax.linen.attention import combine_masks
from jax import lax, vmap
from jax import nn as jax_nn
from jax import random as jax_random
import pytest
from transformer_engine.jax.attention import (
AttnMaskType,
canonicalize_attn_mask_type,
make_swa_mask,
)
from transformer_engine.jax.fp8 import DType as TEDType
from transformer_engine.jax.quantize.helper import DType as TEDType
PRNGKey = Any
Shape = Tuple[int, ...]
......@@ -96,6 +97,62 @@ def combine_biases(*masks: Optional[Array]):
return mask
def parameterize_by_test_level(param_dict: dict, id_prefix: str = ""):
"""
    Takes a dictionary of parameters keyed by test level ("L0", "L1", etc.).
    Returns a list of pytest parameters for the currently selected test level.
"""
DEFAULT_TEST_LEVEL = "L0"
test_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", DEFAULT_TEST_LEVEL)
if test_level not in param_dict:
        raise ValueError(f"Unsupported test level: {test_level}")
return values_to_named_params(param_dict[test_level], id_prefix)
def value_to_test_name_str(value):
"""Converts a value to how it should appear in a test name."""
if isinstance(value, tuple) or isinstance(value, list):
return "_".join([value_to_test_name_str(v) for v in value])
dtype_type = type(jnp.float32)
if isinstance(value, dtype_type):
        return str(value.dtype)
return str(value)
def value_to_named_param(value, id_prefix: str = ""):
param_type = type(pytest.param(0))
if isinstance(value, param_type):
return value
x = pytest.param(value, id=f"{id_prefix}_{value_to_test_name_str(value)}")
return x
def values_to_named_params(params, id_prefix: str = ""):
return [value_to_named_param(v, id_prefix=id_prefix) for v in params]
def pytest_parametrize_wrapper(param_name, param_values):
"""
A wrapper for pytest.mark.parametrize to allow for automatic
naming of tests based on the parameter values.
"""
id_prefix = param_name
if isinstance(param_values, dict):
param_values = parameterize_by_test_level(param_values, id_prefix=param_name)
elif "," not in param_name:
param_values = values_to_named_params(param_values, id_prefix=id_prefix)
    # Comma-separated parameters in a single parametrize call are currently not supported for
    # automatic naming; they are passed through with pytest's default names.
def decorator(func):
return pytest.mark.parametrize(param_name, param_values)(func)
return decorator
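As a point of reference, a minimal usage sketch of the helpers above; the parameter names and values here are illustrative, not taken from the test suite:

# Hypothetical parameters keyed by test level; only the level selected via
# NVTE_JAX_UNITTEST_LEVEL (default "L0") is expanded into pytest params.
EXAMPLE_BATCH_SIZES = {
    "L0": [2],          # quick sanity sweep
    "L1": [2, 8, 32],   # broader sweep
}

@pytest_parametrize_wrapper("batch_size", EXAMPLE_BATCH_SIZES)
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
def test_example(batch_size, dtype):
    assert batch_size > 0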
class DotProductAttention(nn.Module):
transpose_batch_sequence: bool = True
scale_attn_logits: bool = True
......@@ -140,6 +197,7 @@ class DotProductAttention(nn.Module):
Returns:
Output of shape `[batch, length, num_heads, v_depth_per_head]`.
"""
input_dtype = query.dtype
assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
batch_dim = 1 if self.transpose_batch_sequence else 0
assert (
......@@ -152,7 +210,7 @@ class DotProductAttention(nn.Module):
if self.scale_attn_logits:
head_dim = query.shape[-1]
depth_scaling = jnp.sqrt(head_dim).astype(self.dtype)
depth_scaling = jnp.sqrt(head_dim).astype(input_dtype)
query = query / depth_scaling
        # Cast logits and compute the softmax in float32 for model stability.
......@@ -181,7 +239,7 @@ class DotProductAttention(nn.Module):
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
# Normalize the attention weights across `kv_length` dimension.
attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype)
attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype)
# Apply attention dropout.
if not deterministic and self.dropout_rate > 0.0:
......@@ -191,16 +249,20 @@ class DotProductAttention(nn.Module):
dropout_shape = list(attn_weights.shape)
dropout_rng = self.make_rng("dropout")
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
attn_weights = attn_weights.astype(value.dtype)
# attn_weights = attn_weights.astype(input_dtype)
# Take the linear combination of `value`.
if self.transpose_batch_sequence:
return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
assert (
attn_weights.dtype == input_dtype
), f"input.dtype={input_dtype}, output.dtype={attn_weights.dtype}"
return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
......@@ -246,7 +308,6 @@ class DenseGeneral(nn.Module):
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
inputs = jnp.asarray(inputs, self.dtype)
axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
......@@ -268,11 +329,14 @@ class DenseGeneral(nn.Module):
contract_ind = tuple(range(0, len(axis)))
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
y = y.astype(input_dtype)
y = lax.dot_general(
inputs, kernel, ((axis, contract_ind), ((), ())), preferred_element_type=input_dtype
)
if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
assert y.dtype == inputs.dtype, f"input.dtype={inputs.dtype}, output.dtype={y.dtype}"
return y
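For context, a small sketch (shapes and dtypes are illustrative) of what preferred_element_type does in lax.dot_general: it requests the output dtype directly, so the projection result keeps the input dtype without a separate cast:

import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 8), dtype=jnp.bfloat16)
w = jnp.ones((8, 16), dtype=jnp.bfloat16)
# Contract x's last axis with w's first axis and ask for a bfloat16 result
# instead of casting after the matmul.
y = lax.dot_general(x, w, (((1,), (0,)), ((), ())), preferred_element_type=jnp.bfloat16)
assert y.dtype == jnp.bfloat16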
......@@ -352,6 +416,7 @@ class MlpBlock(nn.Module):
)(
x, deterministic=deterministic
) # Broadcast along length.
if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
else:
......@@ -365,6 +430,7 @@ class MlpBlock(nn.Module):
bias_axes="embed",
name="wo",
)(x)
assert (
output.dtype == inputs.dtype
), f"input.dtype={input.dtype}, output.dtype={output.dtype}"
......@@ -391,7 +457,7 @@ def apply_rotary_pos_emb_alternate(
second_part = second_half * cos + first_half * sin
first_part = first_part.astype(inputs.dtype)
second_part = second_part.astype(inputs.dtype)
return jnp.concatenate([first_part, second_part], axis=-1)
return jnp.concatenate([first_part, second_part], axis=-1).astype(inputs.dtype)
def apply_rotary_pos_emb_consecutive(
......@@ -425,7 +491,7 @@ def apply_rotary_pos_emb_consecutive(
sign = jnp.sign(jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2) - 0.5)
outputs = inputs * cos + inputs_shifted * sin * sign
return outputs
return outputs.astype(inputs.dtype)
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
......@@ -559,6 +625,7 @@ class MultiHeadAttention(nn.Module):
if self.fuse_qkv:
if is_qkvpack:
qkv_proj = DenseGeneral(
axis=-1,
features=self.num_heads * self.head_dim * 3,
......@@ -569,11 +636,13 @@ class MultiHeadAttention(nn.Module):
name="qkv",
dtype=self.dtype,
)(inputs_kv)
query, key, value = jnp.split(
qkv_proj,
[self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
axis=-1,
)
else:
query = q_projection(kernel_init=query_init, name="query")(inputs_q)
......@@ -711,6 +780,7 @@ class MultiHeadAttention(nn.Module):
# Convert the boolean attention mask to an attention bias.
if mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(
mask > 0,
jnp.full(mask.shape, 0.0).astype(self.dtype),
......@@ -740,6 +810,7 @@ class MultiHeadAttention(nn.Module):
x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
# Back to the original inputs dimensions.
out = DenseGeneral(
features=inputs_q.shape[-1], # output dim is set to the input dim.
axis=-1,
......@@ -750,6 +821,7 @@ class MultiHeadAttention(nn.Module):
dtype=self.dtype,
name="out",
)(x)
assert (
inputs_q.dtype == inputs_kv.dtype == out.dtype
), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}"
......@@ -784,12 +856,11 @@ class LayerNorm(nn.Module):
scale = nn_partitioning.param_with_axes(
"scale", self.scale_init, (features,), self.dtype, axes=("embed",)
)
scale = jnp.asarray(scale, input_dtype)
x_ = x.astype(jnp.float32)
if self.layernorm_type == "layernorm":
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
y = (x - mean) * lax.rsqrt(var + self.epsilon)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
y = (x_ - mean) * lax.rsqrt(var + self.epsilon)
bias = nn_partitioning.param_with_axes(
"ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
......@@ -803,9 +874,10 @@ class LayerNorm(nn.Module):
else:
assert self.layernorm_type == "rmsnorm"
assert not self.zero_centered_gamma
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = x * lax.rsqrt(mean2 + self.epsilon)
mean2 = jnp.mean(lax.square(x_), axis=-1, keepdims=True)
y = x_ * lax.rsqrt(mean2 + self.epsilon)
z = y * scale
z = z.astype(input_dtype)
assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}"
return z
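A self-contained reference sketch (illustrative, mirroring the rmsnorm branch above) of the float32-statistics pattern the hunk adopts:

import jax.numpy as jnp
from jax import lax

def rmsnorm_reference(x, scale, epsilon=1e-6):
    # Accumulate the mean square in float32, normalize, then cast back so the
    # output dtype matches the input dtype.
    x32 = x.astype(jnp.float32)
    mean2 = jnp.mean(lax.square(x32), axis=-1, keepdims=True)
    y = x32 * lax.rsqrt(mean2 + epsilon)
    return (y * scale).astype(x.dtype)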
......@@ -1085,9 +1157,11 @@ class EncoderLayer(nn.Module):
fuse_wi=self.fuse_mlp_wi,
name="mlp",
)(y, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
y, deterministic=deterministic
)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
......@@ -1103,6 +1177,7 @@ class EncoderLayer(nn.Module):
dtype=self.dtype,
name="output_layernorm",
)(y)
assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}"
return y
......
......@@ -318,7 +318,6 @@ def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
device=device,
with_amax_reduction=True,
amax_reduction_group=tp_group,
amax_reduction_size=tp_size,
)
quantizer = quantizer_class(
fp8_dtype=fp8_dtype,
......
......@@ -741,7 +741,6 @@ def _test_fp8_scale_update(
fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
interval=1,
fp8_format=fp8_format,
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
......
......@@ -286,6 +286,12 @@ def run_dpa_with_cp(
else:
out_.backward(dout_)
if fp8_mha:
assert isinstance(out, Float8Tensor)
assert isinstance(out_, Float8Tensor)
out = out.dequantize()
out_ = out_.dequantize()
for x in [out_, q_.grad, k_.grad, v_.grad]:
assert torch.all(~torch.isnan(x))
assert torch.all(~torch.isinf(x))
......
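A small helper sketch (hypothetical, not from the test file) of the finiteness check pattern used at the end of this hunk:

import torch

def assert_all_finite(*tensors):
    # Fail fast if any tensor contains NaN or Inf values.
    for t in tensors:
        assert torch.all(torch.isfinite(t)), "non-finite values detected"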
......@@ -229,7 +229,7 @@ def get_model(
attn_mask_type = "causal"
qkv_format = "bshd"
if mode == "inference":
attn_mask_type = "padding_causal" if backend != "FusedAttention" else "padding"
attn_mask_type = "padding_causal"
fp8_recipe = recipe.DelayedScaling(
margin=0,
......@@ -392,9 +392,9 @@ def get_tols(module, backend, dtype):
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"])
@pytest.mark.parametrize("is_cuda_graph", [False, True])
@pytest.mark.parametrize("is_fp8", [False, True])
def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8):
def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8):
reset_rng_states()
logger = logging.getLogger("test_paged_attn")
logger = logging.getLogger("test_kv_cache")
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
......@@ -407,7 +407,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda
fp8_meta["recipe"] = fp8_recipe
config = model_configs_infer[model]
num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1
num_layers = 2 if module == "TransformerLayer" else 1
# flash-attn v2 requires page_size >= 256
if backend == "FlashAttention" and not fa_utils.v3_is_installed:
config_max_seqlen_q = config.max_seqlen_q
......@@ -437,7 +437,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda
# initialize inference_params
inference_params = InferenceParams(
max_batch_size=max_batch_size,
max_seqlen_kv=config.max_seqlen_kv,
max_sequence_length=config.max_seqlen_kv,
num_heads_kv=config.num_gqa_groups,
head_dim_k=config.head_dim_qk,
head_dim_v=config.head_dim_v,
......
......@@ -57,6 +57,7 @@ model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
fp8_recipes = [
recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(),
recipe.Float8CurrentScaling(),
]
# Supported data types
......
......@@ -297,7 +297,6 @@ class TestFuser:
fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
interval=1,
fp8_format=fp8_format,
amax_history_len=8,
amax_compute_algo="max",
......
......@@ -56,3 +56,10 @@ def test_torch_dynamo(model_name: str):
# Forward and backward pass
out = model(*inputs)
out.backward(torch.zeros_like(out))
def test_lazy_compile():
"""Smoke test to ensure lazy compilation is working."""
from transformer_engine.pytorch.jit import dgelu_fused_
dgelu_fused_(torch.randn(10, 10), torch.randn(10, 10))
......@@ -2144,7 +2144,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
inference_params = InferenceParams(
max_batch_size=B_max,
max_seqlen_kv=S_max,
max_sequence_length=S_max,
num_heads_kv=H,
head_dim_k=head_size,
dtype=dtype,
......
......@@ -177,7 +177,6 @@ class TestFP8Recipe:
fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
interval=1,
fp8_format=fp8_format,
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
......
......@@ -110,32 +110,17 @@ model_configs = {
}
fp8_recipes = [
None, # Handles non-FP8 case
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.E4M3,
None, # Test non-FP8
recipe.MXFP8BlockScaling(), # Test default
recipe.Float8CurrentScaling(), # Test default
recipe.DelayedScaling(), # Test default
recipe.DelayedScaling( # Test most_recent algo
amax_history_len=16,
amax_compute_algo="most_recent",
),
recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo="max",
),
recipe.DelayedScaling(
margin=0,
recipe.DelayedScaling( # Test custom amax and scale compute algo
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo=custom_amax_compute,
),
recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
scaling_factor_compute_algo=custom_amax_to_scale,
),
]
......@@ -567,6 +552,8 @@ def test_sanity_grouped_linear(
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8():
pytest.skip("Grouped linear does not support MXFP8")
if fp8_recipe.float8_current_scaling():
pytest.skip("Grouped linear does not support FP8 current scaling")
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
......
......@@ -19,9 +19,4 @@ try:
except (ImportError, StopIteration) as e:
pass
try:
import transformer_engine_jax
except ImportError:
pass
__version__ = str(metadata.version("transformer_engine"))
......@@ -233,7 +233,8 @@ if (USE_CUDA)
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart)
CUDA::cudart
CUDNN::cudnn_all)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
......
......@@ -771,6 +771,16 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
}
}
#ifndef __HIP_PLATFORM_AMD__
namespace transformer_engine {
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }
} // namespace transformer_engine
#endif
#ifdef __HIP_PLATFORM_AMD__
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
......
......@@ -140,6 +140,13 @@ constexpr int num_batchgemm_streams = 1;
constexpr int num_streams = 4;
#endif
/*! \brief TE/JAX cudaGraph requires the cuBLAS initialization to happen outside of the capturing
 * region. This function is a helper that calls cublasCreate(), which allocates memory for the
 * handle. It is called in the initialization phase of the related XLA custom calls.
 */
void nvte_cublas_handle_init();
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GEMM_H_
......@@ -149,6 +149,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
void nvte_enable_cudnn_norm_fwd(bool enable);
void nvte_enable_cudnn_norm_bwd(bool enable);
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -80,7 +80,8 @@ enum NVTEScalingMode {
/*! Single scale per block of 32 consecutive elements in either the
    rowwise or columnwise direction */
NVTE_MXFP8_1D_SCALING = 1,
NVTE_INVALID_SCALING
NVTE_INVALID_SCALING = 2,
NVTE_NO_SCALING = 3
};
/*! \brief TE Tensor type
......@@ -346,6 +347,13 @@ enum class DType {
kNumTypes
};
/*! \brief Check if TE datatype is FP8
*
 * Returns true if the TE datatype is an FP8 type.
 * \param[in] t TE datatype of interest
*/
bool is_fp8_dtype(const DType t);
/*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class.
*/
......
......@@ -11,10 +11,12 @@
transformer_engine::ubuf_built_with_mpi*;
*transformer_engine::rtc*;
transformer_engine::nvte_cudnn_handle_init*;
transformer_engine::nvte_cublas_handle_init*;
transformer_engine::typeToSize*;
transformer_engine::is_fp8_dtype*;
*transformer_engine::CommOverlapBase*;
*transformer_engine::CommOverlapP2PBase*;
*transformer_engine::CommOverlapCore*
};
local: *;
};
\ No newline at end of file
};
......@@ -12,6 +12,7 @@
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#endif
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include <functional>
......@@ -141,7 +142,6 @@ struct BackwardKernelParams : public KernelParamsBase {
};
enum class NVTE_Norm_Backend { Te, Cudnn };
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
enum class NVTE_Norm_Stage { Forward, Backward };
using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>;
......
......@@ -162,7 +162,6 @@ class DelayedScaling(Recipe):
"""
margin: int = 0
interval: int = -1
fp8_format: Format = Format.HYBRID
amax_history_len: int = 1024
amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max"
......@@ -173,12 +172,6 @@ class DelayedScaling(Recipe):
def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
if self.interval >= 0:
warnings.warn(
"`interval` argument is deprecated and unused. "
"It will be removed in an upcoming release.",
DeprecationWarning,
)
def __repr__(self) -> str:
return (
......
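For reference, a construction sketch of the recipe after the deprecated interval argument is removed; the values mirror those used by the tests earlier in this commit:

from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(
    margin=0,
    fp8_format=recipe.Format.HYBRID,
    amax_history_len=1024,
    amax_compute_algo="max",
)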