Unverified Commit 9416519d authored by Kirthi Shankar Sivamani, committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
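
Note: the diff below is a mechanical reformatting pass. The C++ JAX extension sources are re-wrapped in a clang-format-style layout, and the Python/Flax modules follow black-style conventions (double-quoted strings, call arguments exploded one per line with trailing commas, roughly 100-character lines). As a minimal, hypothetical sketch of how such a formatting-only change can be verified locally — the helper name, target path, and line length are assumptions inferred from the diff, not part of this commit:

import subprocess

def check_python_formatting(paths, line_length=100):
    """Run black in --check mode; returns True if all files are already formatted."""
    cmd = ["black", "--check", "--line-length", str(line_length), *paths]
    return subprocess.run(cmd, check=False).returncode == 0

if __name__ == "__main__":
    # Assumed target directory; the C++ files could be checked analogously with
    # `clang-format --dry-run -Werror`.
    ok = check_python_formatting(["transformer_engine/jax"])
    print("formatted" if ok else "needs formatting")
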
......@@ -8,7 +8,6 @@
#include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h"
namespace transformer_engine {
namespace jax {
......@@ -83,8 +82,8 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
num_sm, workspace_tensor.data(), barrier_tensor.data());
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, num_sm,
workspace_tensor.data(), barrier_tensor.data());
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
......@@ -122,19 +121,18 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(),
dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), nullptr,
num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dbeta_tensor.data(), dummy_dgamma_part_tensor.data(),
dummy_dbeta_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
dummy_barrier_tensor.data());
dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape());
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dummy_dgamma_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
dummy_barrier_tensor.data());
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(),
xgrad_tensor.data(), wgrad_tensor.data(), dummy_dgamma_part_tensor.data(),
nullptr, num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
dbeta_part_shape = std::vector<size_t>{0, 0};
}
......@@ -183,26 +181,22 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
auto barrier_shape = std::vector<size_t>{barrier_size};
auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
auto dgamma_part_tensor =
TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype);
auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype);
if (is_layer_norm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
auto dbeta_part_tensor =
TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype);
auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype);
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(), dgamma_part_tensor.data(),
dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data());
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dbeta_tensor.data(), dgamma_part_tensor.data(), dbeta_part_tensor.data(),
stream, num_sm, workspace_tensor.data(), barrier_tensor.data());
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data());
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(),
xgrad_tensor.data(), wgrad_tensor.data(), dgamma_part_tensor.data(), stream,
num_sm, workspace_tensor.data(), barrier_tensor.data());
}
}
......
......@@ -6,7 +6,6 @@
#include "jax/csrc/extensions.h"
namespace transformer_engine {
namespace jax {
......@@ -22,8 +21,7 @@ pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape,
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype,
size_t act_enum) {
DType out_dtype, DType wk_dtype, size_t act_enum) {
CustomCallCommonWkDescriptor desc{};
desc.shape.from_vector(shape);
desc.wkshape.from_vector(wkshape);
......@@ -61,8 +59,8 @@ pybind11::bytes PackCustomCallNormDescriptor(
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
size_t head_dim, size_t q_seqlen, size_t k_seqlen,
DType dtype, float scale_factor) {
return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen,
dtype, scale_factor});
return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen, dtype,
scale_factor});
}
pybind11::bytes PackCustomCallFusedAttnDescriptor(
......@@ -72,9 +70,9 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout,
dtype, wkspace_dtype, is_training});
}
} // namespace jax
......
......@@ -49,11 +49,11 @@ pybind11::dict Registrations() {
PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("registrations", &Registrations);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor,
pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0);
m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor,
pybind11::arg(), pybind11::arg(), pybind11::arg(),
pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor, pybind11::arg(), pybind11::arg(),
pybind11::arg(), pybind11::arg("act_num") = 0);
m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor, pybind11::arg(),
pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg(),
pybind11::arg("act_num") = 0);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
......
......@@ -4,9 +4,9 @@
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "transformer_engine/softmax.h"
#include "jax/csrc/extensions.h"
namespace transformer_engine {
namespace jax {
......@@ -23,8 +23,7 @@ void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaqu
auto input_tensor = TensorWrapper(input, shape, dtype);
auto output_tensor = TensorWrapper(output, shape, dtype);
nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), desc.scale_factor,
stream);
nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), desc.scale_factor, stream);
}
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
......@@ -52,8 +51,7 @@ void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto io_shape =
std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto io_shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto mask_shape = std::vector<size_t>{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
......@@ -62,8 +60,8 @@ void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char
auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte);
auto output_tensor = TensorWrapper(output, io_shape, dtype);
nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(),
output_tensor.data(), desc.scale_factor, stream);
nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(), output_tensor.data(),
desc.scale_factor, stream);
}
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
......@@ -105,11 +103,10 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
nvte_scaled_upper_triang_masked_softmax_backward(
grad_output_tensor.data(), softmax_output_tensor.data(), dgrad_tensor.data(),
desc.scale_factor, stream);
nvte_scaled_upper_triang_masked_softmax_backward(grad_output_tensor.data(),
softmax_output_tensor.data(),
dgrad_tensor.data(), desc.scale_factor, stream);
}
} // namespace jax
} // namespace transformer_engine
......@@ -4,9 +4,10 @@
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "transformer_engine/transpose.h"
#include "jax/csrc/extensions.h"
namespace transformer_engine {
namespace jax {
......@@ -58,11 +59,11 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto input_cast_tensor =
TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape,
desc.out_dtype, amax_out, scale, scale_inv);
auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape, desc.out_dtype,
amax_out, scale, scale_inv);
nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(),
input_cast_trans_tensor.data(), stream);
nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(),
stream);
}
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
......@@ -79,9 +80,8 @@ pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hi
TensorWrapper dummy_workspace;
nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
output_trans_tensor.data(), dbias_tensor.data(),
dummy_workspace.data(), nullptr);
nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), dummy_workspace.data(), nullptr);
auto work_shape = MakeShapeVector(dummy_workspace.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
......@@ -122,9 +122,8 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
output_trans_tensor.data(), dbias_tensor.data(),
workspace.data(), stream);
nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
}
} // namespace jax
......
......@@ -4,6 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime_api.h>
#include <cassert>
#include "common/util/cuda_runtime.h"
......
......@@ -7,16 +7,16 @@
#ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#include <pybind11/pybind11.h>
#include <transformer_engine/fused_attn.h>
#include <cstdint>
#include <numeric>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <pybind11/pybind11.h>
#include "common/util/logging.h"
#include <transformer_engine/fused_attn.h>
namespace transformer_engine {
namespace jax {
......
......@@ -18,7 +18,7 @@ def type_safe_dot_general(
x,
kernel,
fp8_meta_pkg: FP8MetaPackage = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,))
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
) -> jnp.ndarray:
"""
Type safe dot_general, including FP8.
......@@ -62,7 +62,8 @@ def fp8_dot_impl(
rhs_scale_inv: jnp.ndarray,
ctype: jnp.dtype, # computing type
contracting_dims: Tuple[Sequence[int], Sequence[int]],
precision: Precision = None):
precision: Precision = None,
):
"""
FP8 GEMM for XLA pattern match
"""
......@@ -82,11 +83,18 @@ def get_precision_of_fp8_dot(enable_2xACC: bool):
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6))
def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, amax_list: List[jnp.ndarray],
scale_list: List[jnp.ndarray], fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
contracting_dims: Tuple[Sequence[int], Sequence[int]]):
output, _ = _fp8_dot_fwd_rule(x, kernel, amax_list, scale_list, fwd_dtype, bwd_dtype,
contracting_dims)
def _fp8_dot(
x: jnp.ndarray,
kernel: jnp.ndarray,
amax_list: List[jnp.ndarray],
scale_list: List[jnp.ndarray],
fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
):
output, _ = _fp8_dot_fwd_rule(
x, kernel, amax_list, scale_list, fwd_dtype, bwd_dtype, contracting_dims
)
return output
......@@ -97,22 +105,25 @@ def _fp8_dot_fwd_rule(
scale_list,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims):
contracting_dims,
):
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list, *scale_list)
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair(
*amax_list, *scale_list
)
amax_list = maybe_fm32_to_fp32(*amax_list)
scale_list = maybe_fm32_to_fp32(*scale_list)
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
x_shape_suf = x.shape[min(lhs_contracting_dims):]
kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
x_shape_suf = x.shape[min(lhs_contracting_dims) :]
kernel_shape_pre = kernel.shape[: max(rhs_contracting_dims) + 1]
assert x_shape_suf == kernel_shape_pre
fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(amax_list, scale_list,
fp8_dtype_list)
scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(
amax_list, scale_list, fp8_dtype_list
)
amax_list = FP8MetaPackage.update_amax_list(amax_list)
x_scale = scale_list[FP8MetaPackage.INPUT_IDX]
......@@ -127,52 +138,100 @@ def _fp8_dot_fwd_rule(
# unnecessary copy to break FP8 GEMM pattern matching.
casted_kernel, updated_kernel_amax = quantize(kernel, fwd_dtype, kernel_scale)
output = fp8_dot_impl(casted_x, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
output = fp8_dot_impl(
casted_x,
casted_kernel,
x_scale_inv,
kernel_scale_inv,
x.dtype,
(lhs_contracting_dims, rhs_contracting_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
)
ctx = (casted_x, casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape, maybe_fp32_to_fm32)
ctx = (
casted_x,
casted_kernel,
amax_list,
scale_list,
scale_inv_list,
updated_x_amax,
updated_kernel_amax,
x.shape,
kernel.shape,
maybe_fp32_to_fm32,
)
return output, ctx
def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # pylint: disable=unused-argument
def _fp8_dot_bwd_rule(
fwd_dtype, bwd_dtype, contracting_dims, ctx, grad
): # pylint: disable=unused-argument
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
casted_x, casted_kernel, amax_list, scale_list, scale_inv_list, \
updated_x_amax, updated_kernel_amax, x_shape, kernel_shape, \
maybe_fp32_to_fm32 = ctx
(
casted_x,
casted_kernel,
amax_list,
scale_list,
scale_inv_list,
updated_x_amax,
updated_kernel_amax,
x_shape,
kernel_shape,
maybe_fp32_to_fm32,
) = ctx
grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1]
grad_scale = scale_list[FP8MetaPackage.GRAD_IDX]
grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX]
casted_grad, casted_grad_t, updated_grad_amax = \
tex.cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv,
bwd_dtype, static_axis_boundary=-1,
transpose_axis_boundary=min(lhs_contracting_dims))
casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose(
grad,
grad_amax,
grad_scale,
grad_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=min(lhs_contracting_dims),
)
x_constracting_dim = tuple(range(0, len(x_shape) - len(lhs_contracting_dims)))
gt_constracting_dim = tuple(range(grad.ndim - len(x_constracting_dim), grad.ndim))
x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
wgrad = fp8_dot_impl(casted_x, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
wgrad = fp8_dot_impl(
casted_x,
casted_grad_t,
x_scale_inv,
grad_scale_inv,
grad.dtype,
(x_constracting_dim, gt_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
)
g_constracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim))
range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim)
)
k_constracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
dgrad = fp8_dot_impl(
casted_grad,
casted_kernel,
grad_scale_inv,
kernel_scale_inv,
grad.dtype,
(g_constracting_dim, k_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
)
amax_list[FP8MetaPackage.INPUT_IDX] = \
amax_list[FP8MetaPackage.INPUT_IDX] = (
amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax)
amax_list[FP8MetaPackage.WEIGHT_IDX] = \
)
amax_list[FP8MetaPackage.WEIGHT_IDX] = (
amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax)
amax_list[FP8MetaPackage.GRAD_IDX] = \
)
amax_list[FP8MetaPackage.GRAD_IDX] = (
amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
)
amax_list = maybe_fp32_to_fm32(*amax_list)
scale_list = maybe_fp32_to_fm32(*scale_list)
......
......@@ -9,15 +9,15 @@ from .transformer import DotProductAttention, MultiHeadAttention, RelativePositi
from .transformer import TransformerLayer, TransformerLayerType
__all__ = [
'DenseGeneral',
'LayerNorm',
'LayerNormDenseGeneral',
'LayerNormMLP',
'TransformerEngineBase',
'extend_logical_axis_rules',
'DotProductAttention',
'MultiHeadAttention',
'RelativePositionBiases',
'TransformerLayer',
'TransformerLayerType',
"DenseGeneral",
"LayerNorm",
"LayerNormDenseGeneral",
"LayerNormMLP",
"TransformerEngineBase",
"extend_logical_axis_rules",
"DotProductAttention",
"MultiHeadAttention",
"RelativePositionBiases",
"TransformerLayer",
"TransformerLayerType",
]
......@@ -30,8 +30,9 @@ PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
lax.Precision]]
PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
......@@ -55,25 +56,22 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
return nn.initializers.zeros
def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes, bias_init,
bias_axes, dtype):
scale = nn_partitioning.param_with_axes('scale',
scale_init,
shape,
jnp.float32,
axes=scale_axes)
def _create_layernorm_parameters(
layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype
):
scale = nn_partitioning.param_with_axes(
"scale", scale_init, shape, jnp.float32, axes=scale_axes
)
scale = jnp.asarray(scale, dtype)
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'layernorm':
bias = nn_partitioning.param_with_axes('ln_bias',
bias_init,
shape,
jnp.float32,
axes=bias_axes)
if layernorm_type == "layernorm":
bias = nn_partitioning.param_with_axes(
"ln_bias", bias_init, shape, jnp.float32, axes=bias_axes
)
bias = jnp.asarray(bias, dtype)
else:
assert layernorm_type == 'rmsnorm'
assert layernorm_type == "rmsnorm"
bias = None
return scale, bias
......@@ -81,7 +79,7 @@ def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes,
def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
"""Convert a string to an activation function."""
if fn_or_string == 'linear':
if fn_or_string == "linear":
return lambda x: x
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
......@@ -96,8 +94,9 @@ def _combine_biases(*masks: List[Array]):
masks = [m for m in masks if m is not None]
if not masks:
return None
assert all(map(lambda x: x.ndim == masks[0].ndim,
masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
assert all(
map(lambda x: x.ndim == masks[0].ndim, masks)
), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
mask, *other_masks = masks
for other_mask in other_masks:
mask = mask + other_mask
......@@ -108,10 +107,10 @@ def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel,
"""Low Rank Adaptation Implementation"""
assert len(axis) <= 5
hidden_in_names = 'ijklm'[:len(axis)]
hidden_in_names = "ijklm"[: len(axis)]
assert len(features) <= 5
hidden_out_names = 'nopqr'[:len(features)]
rank_name = 's'
hidden_out_names = "nopqr"[: len(features)]
rank_name = "s"
assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
rank = lora_a_kernel.shape[-1]
......@@ -121,8 +120,10 @@ def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel,
lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
output_einsum_express = f"...{hidden_out_names}"
final_einsum_express = f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}" \
final_einsum_express = (
f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
f"->{output_einsum_express}"
)
output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
output = output * scaling
......@@ -160,8 +161,9 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
dtype = inputs.dtype
logits = inputs
if (self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype)):
if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype
):
if bias is not None:
logits = logits + bias.astype(dtype)
......@@ -174,9 +176,11 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
else:
attention_bias = None
if mask is not None:
attention_bias = lax.select(mask > 0,
attention_bias = lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(dtype),
jnp.full(mask.shape, 0.).astype(dtype))
jnp.full(mask.shape, 0.0).astype(dtype),
)
if bias is not None:
attention_bias = _combine_biases(attention_bias, bias)
......@@ -186,8 +190,9 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
# For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED
# and kernel is unavailable, then try on pure scaled softmax custom calls.
if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen,
dtype):
if is_softmax_kernel_available(
SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype
):
outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
else:
outputs = jax_nn.softmax(logits * self.scale_factor)
......@@ -262,19 +267,21 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
"""
epsilon: float = 1e-6
layernorm_type: str = 'layernorm'
layernorm_type: str = "layernorm"
zero_centered_gamma: bool = False
scale_init: Initializer = None
scale_axes: Tuple[str, ...] = ('embed',)
scale_axes: Tuple[str, ...] = ("embed",)
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ('embed',)
bias_axes: Tuple[str, ...] = ("embed",)
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
def __post_init__(self):
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init, self.zero_centered_gamma)
self.scale_init, self.zero_centered_gamma
)
super().__post_init__()
@nn.compact
......@@ -294,15 +301,23 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
"""
features = x.shape[-1]
scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
self.scale_init, self.scale_axes,
self.bias_init, self.bias_axes, self.dtype)
return layernorm(x,
scale, ln_bias = _create_layernorm_parameters(
self.layernorm_type,
(features,),
self.scale_init,
self.scale_axes,
self.bias_init,
self.bias_axes,
self.dtype,
)
return layernorm(
x,
scale,
ln_bias,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon)
epsilon=self.epsilon,
)
class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-methods
......@@ -321,18 +336,23 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-meth
grad_name_post_fix = f"_g_{postfix}"
def generate_a_set(target_postfix):
amax = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
amax = nn_partitioning.variable_with_axes(
FP8Helper.FP8_COLLECTION_NAME,
f"{FP8Helper.FP8_AMAX_NAME}{target_postfix}",
jnp.zeros, (FP8Helper.AMAX_HISTORY_LEN,),
jnp.zeros,
(FP8Helper.AMAX_HISTORY_LEN,),
jnp.float32,
axes=(None,))
axes=(None,),
)
scale = nn_partitioning.variable_with_axes(
FP8Helper.FP8_COLLECTION_NAME,
f"{FP8Helper.FP8_SCALE_NAME}{target_postfix}",
jnp.ones, (1,),
jnp.ones,
(1,),
jnp.float32,
axes=(None,))
axes=(None,),
)
return amax.value, scale.value
......@@ -340,8 +360,9 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-meth
weight_amax, weight_scale = generate_a_set(weight_name_post_fix)
grad_amax, grad_scale = generate_a_set(grad_name_post_fix)
return FP8MetaPackage(input_amax, input_scale, weight_amax, weight_scale, grad_amax,
grad_scale)
return FP8MetaPackage(
input_amax, input_scale, weight_amax, weight_scale, grad_amax, grad_scale
)
class DenseGeneral(TransformerEngineBase):
......@@ -403,7 +424,7 @@ class DenseGeneral(TransformerEngineBase):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
super().__post_init__()
@nn.compact
......@@ -430,20 +451,16 @@ class DenseGeneral(TransformerEngineBase):
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes('kernel',
self.kernel_init,
kernel_param_shape,
jnp.float32,
axes=self.kernel_axes)
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
)
kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias:
bias = nn_partitioning.param_with_axes('bias',
self.bias_init,
features,
jnp.float32,
axes=self.bias_axes)
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, jnp.float32, axes=self.bias_axes
)
bias = bias.astype(self.dtype)
else:
bias = None
......@@ -453,36 +470,46 @@ class DenseGeneral(TransformerEngineBase):
if FP8Helper.is_fp8_enabled():
fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")
y = type_safe_dot_general(inputs,
kernel,
fp8_meta_pkg=fp8_meta_pkg,
contracting_dims=(axis, contract_ind))
y = type_safe_dot_general(
inputs, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=(axis, contract_ind)
)
if self.enable_low_rank_adaptation:
lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1],
self.low_rank_adaptation_dim)
lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1],
self.low_rank_adaptation_dim)
lora_a_kernel_shape = (
*kernel_shape[: len(axis)],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_init_shape = (
kernel_param_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel',
lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
jnp.float32,
axes=lora_a_kernel_axes)
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(self.dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel',
lora_b_kernel = nn_partitioning.param_with_axes(
"lora_b_kernel",
nn.initializers.zeros,
lora_b_kernel_shape,
jnp.float32,
axes=lora_b_kernel_axes)
axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
y += _apply_low_rank_adaptation(inputs, axis, features, lora_a_kernel, lora_b_kernel,
self.low_rank_adaptation_alpha)
y += _apply_low_rank_adaptation(
inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
)
if bias is not None:
bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
......@@ -581,13 +608,13 @@ class LayerNormDenseGeneral(TransformerEngineBase):
features: Union[Iterable[int], int]
enable_layernorm: bool = True
layernorm_type: str = 'layernorm'
layernorm_type: str = "layernorm"
epsilon: float = 1e-6
zero_centered_gamma: bool = False
scale_init: Initializer = None
scale_axes: Tuple[str, ...] = ('embed',)
scale_axes: Tuple[str, ...] = ("embed",)
ln_bias_init: Initializer = nn.initializers.zeros
ln_bias_axes: Tuple[str, ...] = ('embed',)
ln_bias_axes: Tuple[str, ...] = ("embed",)
kernel_init: Initializer = None
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = False
......@@ -606,9 +633,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init, self.zero_centered_gamma)
self.scale_init, self.zero_centered_gamma
)
super().__post_init__()
@nn.compact
......@@ -632,8 +660,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
ln_output = None
fuse_layernorm = FP8Helper.is_fp8_enabled(
) and not self.return_layernorm_output and self.enable_layernorm
fuse_layernorm = (
FP8Helper.is_fp8_enabled()
and not self.return_layernorm_output
and self.enable_layernorm
)
if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
......@@ -641,18 +672,25 @@ class LayerNormDenseGeneral(TransformerEngineBase):
assert self.axis == -1  # Only support axis = -1 at this moment
features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
self.scale_init, self.scale_axes,
self.ln_bias_init, self.ln_bias_axes,
self.dtype)
scale, ln_bias = _create_layernorm_parameters(
self.layernorm_type,
(features,),
self.scale_init,
self.scale_axes,
self.ln_bias_init,
self.ln_bias_axes,
self.dtype,
)
if not fuse_layernorm:
y = layernorm(inputs,
y = layernorm(
inputs,
scale,
ln_bias,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon)
epsilon=self.epsilon,
)
else:
assert not self.return_layernorm_output
y = inputs
......@@ -670,11 +708,9 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape = tuple(y.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes('kernel',
self.kernel_init,
kernel_param_shape,
jnp.float32,
axes=self.kernel_axes)
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
)
kernel = jnp.reshape(kernel, kernel_shape)
......@@ -685,7 +721,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")
if fuse_layernorm:
z = layernorm_fp8_dot(y,
z = layernorm_fp8_dot(
y,
kernel,
scale,
ln_bias,
......@@ -694,47 +731,56 @@ class LayerNormDenseGeneral(TransformerEngineBase):
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_input_axes)
dot_input_axes=self.dot_input_axes,
)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
z = type_safe_dot_general(y,
kernel,
fp8_meta_pkg=fp8_meta_pkg,
contracting_dims=(axis, contract_ind))
z = type_safe_dot_general(
y, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=(axis, contract_ind)
)
if self.enable_low_rank_adaptation:
lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1],
self.low_rank_adaptation_dim)
lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1],
self.low_rank_adaptation_dim)
lora_a_kernel_shape = (
*kernel_shape[: len(axis)],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_init_shape = (
kernel_param_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel',
lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
jnp.float32,
axes=lora_a_kernel_axes)
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(self.dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel',
lora_b_kernel = nn_partitioning.param_with_axes(
"lora_b_kernel",
nn.initializers.zeros,
lora_b_kernel_shape,
jnp.float32,
axes=lora_b_kernel_axes)
axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
z += _apply_low_rank_adaptation(y, axis, features, lora_a_kernel, lora_b_kernel,
self.low_rank_adaptation_alpha)
z += _apply_low_rank_adaptation(
y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
)
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('bias',
self.bias_init,
features,
jnp.float32,
axes=self.bias_axes)
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, jnp.float32, axes=self.bias_axes
)
bias = bias.astype(self.dtype)
if bias is not None:
......@@ -858,23 +904,23 @@ class LayerNormMLP(TransformerEngineBase):
intermediate_dim: int = 2048
enable_layernorm: bool = True
layernorm_type: str = 'layernorm'
layernorm_type: str = "layernorm"
epsilon: float = 1e-6
zero_centered_gamma: bool = False
scale_init: Initializer = None
scale_axes: Tuple[str, ...] = ('embed',)
scale_axes: Tuple[str, ...] = ("embed",)
ln_bias_init: Initializer = nn.initializers.zeros
ln_bias_axes: Tuple[str, ...] = ('embed',)
ln_bias_axes: Tuple[str, ...] = ("embed",)
kernel_init: Initializer = None
kernel_axes_1: Tuple[str, ...] = ('embed', 'act', 'mlp')
kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed')
kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
bias_axes_1: Tuple[str, ...] = ('act', 'mlp')
bias_axes_2: Tuple[str, ...] = ('embed',)
bias_axes_1: Tuple[str, ...] = ("act", "mlp")
bias_axes_2: Tuple[str, ...] = ("embed",)
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ('relu',)
intermediate_dropout_rng_name: str = 'dropout'
activations: Sequence[Union[str, Callable]] = ("relu",)
intermediate_dropout_rng_name: str = "dropout"
intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = ()
enable_low_rank_adaptation: bool = False
......@@ -889,9 +935,10 @@ class LayerNormMLP(TransformerEngineBase):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init, self.zero_centered_gamma)
self.scale_init, self.zero_centered_gamma
)
super().__post_init__()
@nn.compact
......@@ -917,24 +964,34 @@ class LayerNormMLP(TransformerEngineBase):
ln_output = None
fuse_layernorm = FP8Helper.is_fp8_enabled(
) and not self.return_layernorm_output and self.enable_layernorm
gated_act_pool = [('gelu', 'linear'), ('silu', 'linear'), ('relu', 'linear'),
('quick_gelu', 'linear'), ('squared_relu', 'linear')]
act_pool = [('gelu',), ('silu',), ('relu',), ('quick_gelu',), ('squared_relu',)]
fuse_layernorm = (
FP8Helper.is_fp8_enabled()
and not self.return_layernorm_output
and self.enable_layernorm
)
gated_act_pool = [
("gelu", "linear"),
("silu", "linear"),
("relu", "linear"),
("quick_gelu", "linear"),
("squared_relu", "linear"),
]
act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
normalized_acts = []
for act in self.activations:
if not isinstance(act, str):
return False
normalized_acts.append(act.lower())
normalized_acts = tuple(
reversed(normalized_acts) if normalized_acts[0] == 'linear' else normalized_acts)
reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
)
is_act_implemented = normalized_acts in (gated_act_pool + act_pool)
use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\
self.intermediate_dropout_rate < 1e-3
use_fused_layernorm_mlp = (
fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
)
# LayerNorm
if self.enable_layernorm:
......@@ -943,18 +1000,25 @@ class LayerNormMLP(TransformerEngineBase):
features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
self.scale_init, self.scale_axes,
self.ln_bias_init, self.ln_bias_axes,
self.dtype)
scale, ln_bias = _create_layernorm_parameters(
self.layernorm_type,
(features,),
self.scale_init,
self.scale_axes,
self.ln_bias_init,
self.ln_bias_axes,
self.dtype,
)
if not fuse_layernorm:
y = layernorm(inputs,
y = layernorm(
inputs,
scale,
ln_bias,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon)
epsilon=self.epsilon,
)
else:
assert not self.return_layernorm_output
y = inputs
......@@ -984,55 +1048,58 @@ class LayerNormMLP(TransformerEngineBase):
intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1 = nn_partitioning.param_with_axes('wi_kernel',
kernel_1 = nn_partitioning.param_with_axes(
"wi_kernel",
kernel_1_init,
num_activations,
-2,
kernel_1_each_shape,
jnp.float32,
axes=self.kernel_axes_1)
axes=self.kernel_axes_1,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
kernel_2 = nn_partitioning.param_with_axes('wo_kernel',
kernel_2 = nn_partitioning.param_with_axes(
"wo_kernel",
self.kernel_init,
kernel_2_param_shape,
jnp.float32,
axes=self.kernel_axes_2)
axes=self.kernel_axes_2,
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
contract_ind = tuple(range(0, len(axis)))
ffn1_ckpt_name = 'ffn1'
ffn2_ckpt_name = 'ffn2'
ffn1_ckpt_name = "ffn1"
ffn2_ckpt_name = "ffn2"
if use_fused_layernorm_mlp:
assert self.axis == -1  # Only support axis = -1 at this moment
if self.use_bias:
bias_1_shape = intermediate_dim
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
bias_1_shape,
jnp.float32,
axes=self.bias_axes_1)
bias_1 = nn_partitioning.param_with_axes(
"wi_bias", self.bias_init, bias_1_shape, jnp.float32, axes=self.bias_axes_1
)
bias_1 = bias_1.astype(self.dtype)
bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init,
bias_2_shape,
jnp.float32,
axes=self.bias_axes_2)
bias_2 = nn_partitioning.param_with_axes(
"wo_bias", self.bias_init, bias_2_shape, jnp.float32, axes=self.bias_axes_2
)
bias_2 = bias_2.astype(self.dtype)
else:
bias_1 = None
bias_2 = None
out = fused_layernorm_fp8_mlp(y,
out = fused_layernorm_fp8_mlp(
y,
scale,
ln_bias, [kernel_1, kernel_2], [bias_1, bias_2],
ln_bias,
[kernel_1, kernel_2],
[bias_1, bias_2],
[wi_fp8_meta_pkg, wo_fp8_meta_pkg],
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
......@@ -1043,12 +1110,14 @@ class LayerNormMLP(TransformerEngineBase):
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name,
activation_type=normalized_acts,
use_bias=self.use_bias)
use_bias=self.use_bias,
)
else: # not use_fused_ln_geglu_mlp
# DenseGeneral 1
if fuse_layernorm:
x = layernorm_fp8_dot(y,
x = layernorm_fp8_dot(
y,
kernel_1,
scale,
ln_bias,
......@@ -1057,52 +1126,71 @@ class LayerNormMLP(TransformerEngineBase):
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_1_input_axes)
dot_input_axes=self.dot_1_input_axes,
)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
x = type_safe_dot_general(y,
kernel_1,
fp8_meta_pkg=wi_fp8_meta_pkg,
contracting_dims=(axis, contract_ind))
x = type_safe_dot_general(
y, kernel_1, fp8_meta_pkg=wi_fp8_meta_pkg, contracting_dims=(axis, contract_ind)
)
if self.enable_low_rank_adaptation:
wi_lora_a_kernel_shape = (*kernel_1_shape[:len(axis)], num_activations,
self.low_rank_adaptation_dim)
wi_lora_a_kernel_init_shape = (kernel_1_each_shape[0], num_activations,
self.low_rank_adaptation_dim)
wi_lora_a_kernel_init_each_shape = (kernel_1_each_shape[0],
self.low_rank_adaptation_dim)
wi_lora_a_kernel_shape = (
*kernel_1_shape[: len(axis)],
num_activations,
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_shape = (
kernel_1_each_shape[0],
num_activations,
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_each_shape = (
kernel_1_each_shape[0],
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
wi_lora_a_kernel = nn_partitioning.param_with_axes('wi_lora_a_kernel',
wi_lora_a_kernel = nn_partitioning.param_with_axes(
"wi_lora_a_kernel",
kernel_1_init,
num_activations,
-2,
wi_lora_a_kernel_init_each_shape,
jnp.float32,
axes=wi_lora_a_kernel_axes)
axes=wi_lora_a_kernel_axes,
)
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype)
wi_lora_b_kernel_shape = (num_activations, self.low_rank_adaptation_dim,
self.intermediate_dim)
wi_lora_b_kernel_shape = (
num_activations,
self.low_rank_adaptation_dim,
self.intermediate_dim,
)
wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
wi_lora_b_kernel = nn_partitioning.param_with_axes('wi_lora_b_kernel',
wi_lora_b_kernel = nn_partitioning.param_with_axes(
"wi_lora_b_kernel",
nn.initializers.zeros,
wi_lora_b_kernel_shape,
jnp.float32,
axes=wi_lora_b_kernel_axes)
axes=wi_lora_b_kernel_axes,
)
wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel,
wi_lora_b_kernel, self.low_rank_adaptation_alpha)
x += _apply_low_rank_adaptation(
y,
axis,
intermediate_dim,
wi_lora_a_kernel,
wi_lora_b_kernel,
self.low_rank_adaptation_alpha,
)
bias_1 = None
if self.use_bias:
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
intermediate_dim,
jnp.float32,
axes=self.bias_axes_1)
bias_1 = nn_partitioning.param_with_axes(
"wi_bias", self.bias_init, intermediate_dim, jnp.float32, axes=self.bias_axes_1
)
bias_1 = bias_1.astype(self.dtype)
bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
x += jnp.reshape(bias_1, bias_1_shape)
......@@ -1120,47 +1208,56 @@ class LayerNormMLP(TransformerEngineBase):
# Remove act axis
z = jnp.reshape(z, (*z.shape[:-2], -1))
z = nn.Dropout(rate=self.intermediate_dropout_rate,
z = nn.Dropout(
rate=self.intermediate_dropout_rate,
broadcast_dims=self.intermediate_hidden_dropout_dims,
rng_collection=self.intermediate_dropout_rng_name)(
z, deterministic=deterministic)
rng_collection=self.intermediate_dropout_rng_name,
)(z, deterministic=deterministic)
z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
# DenseGeneral 2
out = type_safe_dot_general(z,
kernel_2,
fp8_meta_pkg=wo_fp8_meta_pkg,
contracting_dims=(axis, contract_ind))
out = type_safe_dot_general(
z, kernel_2, fp8_meta_pkg=wo_fp8_meta_pkg, contracting_dims=(axis, contract_ind)
)
if self.enable_low_rank_adaptation:
wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
wo_lora_a_kernel = nn_partitioning.param_with_axes('wo_lora_a_kernel',
wo_lora_a_kernel = nn_partitioning.param_with_axes(
"wo_lora_a_kernel",
self.kernel_init,
wo_lora_a_kernel_shape,
jnp.float32,
axes=wo_lora_a_kernel_axes)
axes=wo_lora_a_kernel_axes,
)
wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
wo_lora_b_kernel = nn_partitioning.param_with_axes('wo_lora_b_kernel',
wo_lora_b_kernel = nn_partitioning.param_with_axes(
"wo_lora_b_kernel",
nn.initializers.zeros,
wo_lora_b_kernel_shape,
jnp.float32,
axes=wo_lora_b_kernel_axes)
axes=wo_lora_b_kernel_axes,
)
wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel,
wo_lora_b_kernel, self.low_rank_adaptation_alpha)
out += _apply_low_rank_adaptation(
z,
axis,
hidden_size_tuple,
wo_lora_a_kernel,
wo_lora_b_kernel,
self.low_rank_adaptation_alpha,
)
bias_2 = None
if self.use_bias:
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,),
jnp.float32,
axes=self.bias_axes_2)
bias_2 = nn_partitioning.param_with_axes(
"wo_bias", self.bias_init, (hidden_size,), jnp.float32, axes=self.bias_axes_2
)
bias_2 = bias_2.astype(self.dtype)
out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
......
......@@ -39,8 +39,9 @@ PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
lax.Precision]]
PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
LogicalRules = Sequence[Tuple[str, Union[str, None]]]
......@@ -82,14 +83,13 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
"""
rules_map = {}
for item in rules:
assert len(item) == 2, \
"The logical axis rule should be like (axis_name, mesh_axis_name)."
assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)."
key = item[0]
val = item[1]
assert isinstance(key, str), \
f"Thie axis_name should be str, but got {type(key)}."
assert isinstance(val, str) or (val is None), \
f"Thie mesh_axis_name should be str or None, but got {type(val)}."
assert isinstance(key, str), f"Thie axis_name should be str, but got {type(key)}."
assert isinstance(val, str) or (
val is None
), f"Thie mesh_axis_name should be str or None, but got {type(val)}."
if key in rules_map:
rules_map[key].append(val)
else:
......@@ -100,17 +100,18 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
key = item[0]
val = item[1]
if key in rules_map:
assert len(rules_map[key]) == 1 and rules_map[key][0] == val, \
f"The rule diverged between TE and given rule." \
f"Axis:{key} map to {rules_map[key]} in the given" \
assert len(rules_map[key]) == 1 and rules_map[key][0] == val, (
"The rule diverged between TE and given rule."
f"Axis:{key} map to {rules_map[key]} in the given"
f" rules, but {val} in TE's rules."
)
else:
extended_rules.append(item)
return tuple(extended_rules)
class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout: float = 0.
attention_dropout: float = 0.0
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32
......@@ -119,7 +120,8 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi
transpose_batch_sequence: bool = True
@nn.compact
def __call__(self,
def __call__(
self,
query: Array,
key: Array,
value: Array,
......@@ -127,15 +129,17 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi
bias: Optional[Array] = None,
*,
dropout_rng: Optional[PRNGKey] = None,
deterministic: bool = False) -> Array:
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
deterministic: bool = False,
) -> Array:
assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
batch_dim = 1 if self.transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.')
assert (
query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
), "q, k, v batch dims must match."
sequence_dim = 0 if self.transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert key.shape[-2] == value.shape[-2], 'k, v num_attention_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
if self.scale_factor is None:
scale_factor = 1.0 / sqrt(query.shape[-1])
......@@ -149,7 +153,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi
h_q, h_kv = query.shape[-2], key.shape[-2]
# The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
# Therefore, we have to maintain two code paths.
is_gqa = (h_q != h_kv)
is_gqa = h_q != h_kv
if is_gqa:
assert (h_q % h_kv == 0) and (h_q >= h_kv)
......@@ -158,16 +162,16 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi
if self.transpose_batch_sequence:
if is_gqa:
attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
else:
attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key)
else:
if is_gqa:
attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)
attn_weights = checkpoint_name(attn_weights, 'logits')
attn_weights = checkpoint_name(attn_weights, "logits")
if is_gqa:
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
......@@ -175,13 +179,14 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
attn_weights = with_sharding_constraint_by_logical_axes(
attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)
)
# When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
# In this case, the scale can not fused into the Softmax module.
if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
attn_weights = attn_weights * scale_factor
fused_scale_factor = 1.
fused_scale_factor = 1.0
else:
# If not post_scale_bias, the scale can be fused into Softmax module
fused_scale_factor = scale_factor
......@@ -199,39 +204,40 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi
if mask is not None:
return SoftmaxType.SCALED_MASKED, mask
return SoftmaxType.SCALED, mask
raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type="
"{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}")
raise ValueError(
f"Unsupported {attn_mask_type=}, supported attn_mask_type="
"{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
)
softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)
attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor)(attn_weights, mask,
bias).astype(self.dtype)
attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(
attn_weights, mask, bias
).astype(self.dtype)
if is_gqa:
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
if not deterministic and self.attention_dropout > 0.:
if not deterministic and self.attention_dropout > 0.0:
keep_prob = 1.0 - self.attention_dropout
dropout_shape = list(attn_weights.shape)
# TODO(rewang): add attention dropout broadcast dimension arguments for users
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = (keep.astype(attn_weights.dtype) /
jnp.asarray(keep_prob, dtype=self.dtype))
multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
attn_weights = attn_weights * multiplier
if self.transpose_batch_sequence:
if is_gqa:
return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)
return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value)
if is_gqa:
return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout: float = 0.
attention_dropout: float = 0.0
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32
......@@ -240,7 +246,8 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
transpose_batch_sequence: bool = False
@nn.compact
def __call__(self,
def __call__(
self,
query: Array,
key: Array,
value: Array,
......@@ -248,7 +255,8 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
bias: Optional[Array] = None,
*,
dropout_rng: Optional[PRNGKey] = None,
deterministic: bool = False) -> Array:
deterministic: bool = False,
) -> Array:
seed = None
if dropout_rng is not None:
......@@ -269,7 +277,8 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
qkv_packed = query
if self.transpose_batch_sequence:
qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
x = fused_attn_qkvpacked(qkv_packed,
x = fused_attn_qkvpacked(
qkv_packed,
bias,
mask,
seed,
......@@ -277,7 +286,8 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_bias_type=self.attn_bias_type,
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic)
is_training=not deterministic,
)
elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
"""kvpacked format, treat
query: query tensor, shape = [..., h, d]
......@@ -288,7 +298,8 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if self.transpose_batch_sequence:
query = query.transpose([1, 0, 2, 3])
kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
x = fused_attn_kvpacked(query,
x = fused_attn_kvpacked(
query,
kv_packed,
bias,
mask,
......@@ -297,13 +308,15 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_bias_type=self.attn_bias_type,
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic)
is_training=not deterministic,
)
elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
if self.transpose_batch_sequence:
query = query.transpose([1, 0, 2, 3])
key = key.transpose([1, 0, 2, 3])
value = value.transpose([1, 0, 2, 3])
x = fused_attn(query,
x = fused_attn(
query,
key,
value,
bias,
......@@ -313,7 +326,8 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_bias_type=self.attn_bias_type,
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic)
is_training=not deterministic,
)
else:
raise ValueError(f"Unsupported {self.qkv_layout=}.")
......@@ -423,28 +437,31 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
"""
head_dim: int
num_attention_heads: int
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.
attn_mask_type: AttnMaskType = 'causal'
attention_dropout: float = 0.0
attn_mask_type: AttnMaskType = "causal"
attn_bias_type: AttnBiasType = None
dtype: DType = jnp.float32
dropout_rng_name: str = 'dropout'
dropout_rng_name: str = "dropout"
float32_logits: bool = False
qkv_layout: str = 'bshd_bshd_bshd'
qkv_layout: str = "bshd_bshd_bshd"
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
@nn.compact
def __call__(self,
def __call__(
self,
query: Array,
key: Array,
value: Array,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
*,
deterministic: bool = False) -> Array:
deterministic: bool = False,
) -> Array:
"""
Parameters
----------
......@@ -494,25 +511,34 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method
else:
seqlen_kv = key.shape[sequence_dim]
has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
attn_bias_type, attn_mask_type,
has_fused_attn_kernel = is_fused_attn_kernel_available(
self.dtype,
self.dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
self.attention_dropout,
self.num_attention_heads,
self.num_gqa_groups, seqlen_q,
seqlen_kv, self.head_dim)
self.num_gqa_groups,
seqlen_q,
seqlen_kv,
self.head_dim,
)
use_fused_attn = (enable_fused_attn and has_fused_attn_kernel)
use_fused_attn = enable_fused_attn and has_fused_attn_kernel
if enable_fused_attn and not has_fused_attn_kernel:
warnings.warn("Fused attention is not enabled because there is no available kernel.\n"
warnings.warn(
"Fused attention is not enabled because there is no available kernel.\n"
"Fall back to the unfused attention.\n"
"Please try to update the cuDNN and TE to the latest version.\n"
f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n")
f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n"
)
dropout_rng = None
if not deterministic and self.attention_dropout > 0.:
if not deterministic and self.attention_dropout > 0.0:
dropout_rng = self.make_rng(self.dropout_rng_name)
if self.scale_factor is None:
......@@ -525,28 +551,24 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method
# unfused attention only supports split query, key, value
if qkv_layout == QKVLayout.BS3HD:
query, key, value = jnp.split(query, [1, 2], axis=-3)
query, key, value = map(functools.partial(jnp.squeeze, axis=-3),
[query, key, value])
query, key, value = map(
functools.partial(jnp.squeeze, axis=-3), [query, key, value]
)
elif qkv_layout == QKVLayout.BSHD_BS2HD:
key, value = jnp.split(key, [1], axis=-3)
key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
else:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
x = _UnfusedDotProductAttention(attention_dropout=self.attention_dropout,
x = _UnfusedDotProductAttention(
attention_dropout=self.attention_dropout,
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
dtype=self.dtype,
float32_logits=self.float32_logits,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence)(
query,
key,
value,
mask,
bias,
dropout_rng=dropout_rng,
deterministic=deterministic)
transpose_batch_sequence=self.transpose_batch_sequence,
)(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
else:
x = _FusedDotProductAttention(
attention_dropout=self.attention_dropout,
......@@ -561,10 +583,12 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method
return x
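A standalone sketch of the unpacking used by the unfused path above: jnp.split along the packing axis followed by a squeeze restores separate query/key/value tensors (hypothetical shapes).

import functools
import jax.numpy as jnp

b, s, h, d = 2, 8, 4, 16
qkv_packed = jnp.ones((b, s, 3, h, d))                       # BS3HD packed projection
query, key, value = jnp.split(qkv_packed, [1, 2], axis=-3)   # three (b, s, 1, h, d) slices
query, key, value = map(functools.partial(jnp.squeeze, axis=-3), [query, key, value])
# each of query/key/value is now (b, s, h, d)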
def rotary_pos_emb(x: Array,
def rotary_pos_emb(
x: Array,
windows: Tuple[int, int],
transpose_batch_sequence: bool,
group_method: str = 'consecutive'):
group_method: str = "consecutive",
):
"""
Rotary Positional Embedding
x should be in shape of
......@@ -577,7 +601,7 @@ def rotary_pos_emb(x: Array,
max_window = windows[1]
fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
time_scales = min_window * (max_window / min_window)**fraction
time_scales = min_window * (max_window / min_window) ** fraction
time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))
batch_dim = 1 if transpose_batch_sequence else 0
......@@ -623,16 +647,17 @@ def rotary_pos_emb(x: Array,
return output
def canonicalize_group_method(gm):
canonicalized_gm = gm.lower().strip().replace('-', '').replace('_', '')
assert canonicalized_gm in ['consecutive', 'alternate'], \
f"Invalid relative positional embedding group method. " \
canonicalized_gm = gm.lower().strip().replace("-", "").replace("_", "")
assert canonicalized_gm in ["consecutive", "alternate"], (
"Invalid relative positional embedding group method. "
f"Expect to be in []'alternate' or 'consecutive'], but got {gm}."
)
return canonicalized_gm
group_method = canonicalize_group_method(group_method)
if group_method == 'alternate':
if group_method == "alternate":
return alternate_impl()
return consecutive_impl()
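A standalone numeric sketch of the time-scale schedule above: for hidden_dim = 8 and windows (1, 10000), the per-pair scales grow geometrically from the minimum window to the maximum window.

import jax.numpy as jnp

hidden_dim = 8
half_hidden_dim = hidden_dim // 2
min_window, max_window = 1, 10000

fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim         # [0.0, 0.25, 0.5, 0.75]
time_scales = min_window * (max_window / min_window) ** fraction   # [1.0, 10.0, 100.0, 1000.0]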
......@@ -646,28 +671,37 @@ class LoRAScope: # pylint: disable=too-few-public-methods
self.mlp = mlp
def __eq__(self, other):
return (self.qkv_proj, self.output_proj, self.mlp) == \
(other.qkv_proj, other.output_proj, other.mlp)
return (self.qkv_proj, self.output_proj, self.mlp) == (
other.qkv_proj,
other.output_proj,
other.mlp,
)
def _canonicalize_lora_scope(scope):
SCOPE_NONE = 'none'
SCOPE_ALL = 'all'
SCOPE_QKV_PROJ = 'qkv_proj'
SCOPE_OUTPUT_PROJ = 'output_proj'
SCOPE_MLP = 'mlp'
SCOPE_EX_QKV_PROJ = 'exclude_qkv_proj'
SCOPE_EX_OUTPUT_PROJ = 'exclude_output_proj'
SCOPE_EX_MLP = 'exclude_mlp'
SCOPE_NONE = "none"
SCOPE_ALL = "all"
SCOPE_QKV_PROJ = "qkv_proj"
SCOPE_OUTPUT_PROJ = "output_proj"
SCOPE_MLP = "mlp"
SCOPE_EX_QKV_PROJ = "exclude_qkv_proj"
SCOPE_EX_OUTPUT_PROJ = "exclude_output_proj"
SCOPE_EX_MLP = "exclude_mlp"
scope = SCOPE_NONE if scope is None else scope
scope = scope.lower()
assert scope in [
SCOPE_NONE, SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_OUTPUT_PROJ, SCOPE_MLP, SCOPE_EX_QKV_PROJ,
SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP
SCOPE_NONE,
SCOPE_ALL,
SCOPE_QKV_PROJ,
SCOPE_OUTPUT_PROJ,
SCOPE_MLP,
SCOPE_EX_QKV_PROJ,
SCOPE_EX_OUTPUT_PROJ,
SCOPE_EX_MLP,
]
lora_scope = LoRAScope()
......@@ -818,8 +852,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
head_dim: int
num_attention_heads: int
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.
dropout_rng_name: str = 'dropout'
attention_dropout: float = 0.0
dropout_rng_name: str = "dropout"
input_layernorm: bool = True
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
......@@ -828,12 +862,12 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init: Initializer = None
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
attn_mask_type: str = 'causal'
attn_mask_type: str = "causal"
attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
low_rank_adaptation_scope: str = 'none'
rotary_pos_emb_group_method: str = "consecutive"
low_rank_adaptation_scope: str = "none"
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
......@@ -857,40 +891,50 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
self.num_attention_heads = self.num_heads
warnings.warn(
f"{__class__}.num_heads is deprecated. It will be removed recently. "
f"Please uses {__class__}.num_attention_heads as the new API.", DeprecationWarning)
f"Please uses {__class__}.num_attention_heads as the new API.",
DeprecationWarning,
)
if self.dropout_rate is not None:
self.attention_dropout = self.dropout_rate
warnings.warn(
f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning)
f"Please use {__class__}.attention_dropout as the new API.",
DeprecationWarning,
)
if self.apply_residual_connection_post_layernorm is not None:
warnings.warn(
f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
f"It will be removed recently, please use {__class__}.return_layernorm_output.",
DeprecationWarning)
DeprecationWarning,
)
if self.fuse_qkv is not None:
warnings.warn(
f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning)
f"Please use {__class__}.fuse_qkv_params as the new API.",
DeprecationWarning,
)
assert self.output_layernorm is None, (
f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm.")
f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
def __call__(self,
def __call__(
self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
*,
decode: bool = False,
deterministic: bool = False) -> Array:
deterministic: bool = False,
) -> Array:
"""
MultiHeadAttention Layer:
[Query, Key, Value projection] -> Dot Product Attention -> Output projection.
......@@ -963,12 +1007,14 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
return tuple(axes)
is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_attention_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
is_self_attn = inputs_q is inputs_kv
is_gqa = self.num_attention_heads != self.num_gqa_groups
is_qkvpack = is_self_attn and not is_gqa
inputs_logical_axes_maybe_sp = (*generate_batch_seqlen_logical_axes(
self.enable_sequence_parallel), HIDDEN_AXES)
inputs_logical_axes_maybe_sp = (
*generate_batch_seqlen_logical_axes(self.enable_sequence_parallel),
HIDDEN_AXES,
)
inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)
inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)
......@@ -998,9 +1044,10 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
name='qkv',
dtype=self.dtype)(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
name="qkv",
dtype=self.dtype,
)(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
qkv_layout = QKVLayout.BS3HD
else:
query, ln_out = LayerNormDenseGeneral(
......@@ -1025,13 +1072,15 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
name='query')(inputs_q)
name="query",
)(inputs_q)
if is_self_attn:
assert ln_out is not None
inputs_kv = ln_out
kv_proj = DenseGeneral(axis=-1,
kv_proj = DenseGeneral(
axis=-1,
features=(2, self.num_gqa_groups * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
......@@ -1042,9 +1091,10 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
name='kv',
dtype=self.dtype)(inputs_kv)
kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
name="kv",
dtype=self.dtype,
)(inputs_kv)
kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
qkv_layout = QKVLayout.BSHD_BS2HD
else:
kv_projection = functools.partial(
......@@ -1059,7 +1109,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype)
dtype=self.dtype,
)
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=self.input_layernorm,
layernorm_type=self.layernorm_type,
......@@ -1082,17 +1133,18 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
name='query')(inputs_q)
name="query",
)(inputs_q)
if is_self_attn:
assert ln_out is not None
inputs_kv = ln_out
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj')
value = checkpoint_name(value, 'value_proj')
key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
query = checkpoint_name(query, "query_proj")
key = checkpoint_name(key, "key_proj")
value = checkpoint_name(value, "value_proj")
qkv_layout = QKVLayout.BSHD_BSHD_BSHD
if self.enable_rotary_pos_emb:
......@@ -1107,10 +1159,18 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
query = rotary_pos_emb(query, self.rotary_pos_emb_windows,
self.transpose_batch_sequence, self.rotary_pos_emb_group_method)
key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence,
self.rotary_pos_emb_group_method)
query = rotary_pos_emb(
query,
self.rotary_pos_emb_windows,
self.transpose_batch_sequence,
self.rotary_pos_emb_group_method,
)
key = rotary_pos_emb(
key,
self.rotary_pos_emb_windows,
self.transpose_batch_sequence,
self.rotary_pos_emb_group_method,
)
qkv_layout = QKVLayout.BSHD_BSHD_BSHD
if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
......@@ -1120,13 +1180,15 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
if decode:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
is_initialized = self.has_variable('cache', 'cached_key')
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype)
cached_value = self.variable('cache', 'cached_value', jnp.zeros, value.shape,
value.dtype)
cache_index = self.variable('cache', 'cache_index',
lambda: jnp.array(0, dtype=jnp.int32))
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable(
"cache", "cached_value", jnp.zeros, value.shape, value.dtype
)
cache_index = self.variable(
"cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
)
if is_initialized:
if self.transpose_batch_sequence:
length, batch, num_attention_heads, head_dim = cached_key.value.shape
......@@ -1140,8 +1202,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
# Sanity shape check of cached key against input query.
if expected_shape != query.shape:
raise ValueError(
'Autoregressive cache shape error, '
f"expected query shape {expected_shape} instead got {query.shape}.")
"Autoregressive cache shape error, "
f"expected query shape {expected_shape} instead got {query.shape}."
)
cur_index = cache_index.value
one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
......@@ -1153,21 +1216,25 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
cache_index.value = cache_index.value + 1
mask = combine_masks(
mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)))
mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
)
if bias is not None:
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim,
in_axes=(None, 0, None, None))
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
jnp.reshape(cur_index, (-1)), 1, -2)
dynamic_vector_slice_in_dim = vmap(
lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)
)
bias = dynamic_vector_slice_in_dim(
jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
)
LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
if self.transpose_batch_sequence:
LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)
if qkv_layout == QKVLayout.BS3HD:
qkv_proj = qkv_proj.reshape(*qkv_proj.shape[:2], 3, self.num_attention_heads,
self.head_dim)
qkv_proj = qkv_proj.reshape(
*qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim
)
qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
dpa_args = [qkv_proj, None, None]
......@@ -1191,7 +1258,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
dpa_args = [query, key, value]
scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
x = DotProductAttention(head_dim=self.head_dim,
x = DotProductAttention(
head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
attn_mask_type=self.attn_mask_type,
......@@ -1202,14 +1270,15 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
float32_logits=self.float32_logits,
qkv_layout=qkv_layout.name,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence)(
*dpa_args, mask, bias, deterministic=deterministic)
transpose_batch_sequence=self.transpose_batch_sequence,
)(*dpa_args, mask, bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
out = DenseGeneral(features=inputs_q.shape[-1],
out = DenseGeneral(
features=inputs_q.shape[-1],
transpose_batch_sequence=self.transpose_batch_sequence,
axis=-1,
kernel_init=self.kernel_init,
......@@ -1221,8 +1290,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
name='out')(x)
out = checkpoint_name(out, 'out_proj')
name="out",
)(x)
out = checkpoint_name(out, "out_proj")
return out, ln_out
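For the autoregressive decode cache above, the exact scatter falls outside this hunk; a common one-hot formulation (a sketch under that assumption, batch-first layout) writes the current step's key into the slot selected by cache_index:

import jax
import jax.numpy as jnp

batch, length, heads, dim = 1, 8, 2, 4
cached_key = jnp.zeros((batch, length, heads, dim))   # cache covering the full decode length
step_key = jnp.ones((batch, 1, heads, dim))           # key produced at the current decode step
cur_index = 3

one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=step_key.dtype)   # (length,)
cached_key = cached_key + step_key * one_hot_indices[None, :, None, None]   # write slot 3 only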
......@@ -1250,11 +1320,12 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-met
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
"""
num_buckets: int
max_distance: int
num_attention_heads: int
embedding_init: Callable[..., Array] = nn.linear.default_embed_init
embedding_axes: Tuple[str, ...] = ('heads', 'relpos_buckets')
embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
dtype: DType = jnp.float32
@nn.compact
......@@ -1296,26 +1367,30 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-met
rpb_max_exact = rpb_num_buckets // 2
rpb_is_small = negative_rp < rpb_max_exact
rpb_val_if_large = rpb_max_exact + (
np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps) /
np.log(self.max_distance / rpb_max_exact) *
(rpb_num_buckets - rpb_max_exact)).astype(np.int32)
np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps)
/ np.log(self.max_distance / rpb_max_exact)
* (rpb_num_buckets - rpb_max_exact)
).astype(np.int32)
rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)
# Compute relative attention bias
relative_attention_bias = nn_partitioning.param_with_axes(
'rel_embedding',
self.embedding_init, (self.num_attention_heads, self.num_buckets),
"rel_embedding",
self.embedding_init,
(self.num_attention_heads, self.num_buckets),
jnp.float32,
axes=self.embedding_axes)
axes=self.embedding_axes,
)
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
values = lax.dot_general(relative_attention_bias, rp_bucket_one_hot,
(((1,), (0,)), ((), ())))
values = lax.dot_general(
relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ()))
)
return values[jnp.newaxis, ...]
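A standalone numeric sketch of the bucketing above (num_buckets = 32, max_distance = 128): distances below num_buckets // 2 get their own bucket, while larger distances are placed logarithmically and clamped to the last bucket.

import numpy as np

num_buckets, max_distance = 32, 128
max_exact = num_buckets // 2                          # 16

d = np.array([1, 8, 15, 16, 32, 64, 127])             # example non-negative relative distances
is_small = d < max_exact
val_if_large = max_exact + (
    np.log(d.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
    / np.log(max_distance / max_exact)
    * (num_buckets - max_exact)
).astype(np.int32)
val_if_large = np.minimum(val_if_large, num_buckets - 1)
bucket = np.where(is_small, d, val_if_large)           # [1, 8, 15, 16, 21, 26, 31]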
......@@ -1330,6 +1405,7 @@ class TransformerLayerType(Enum):
DECODER:
Decoder type of TransformerLayer.
"""
ENCODER = "encoder"
DECODER = "decoder"
......@@ -1497,7 +1573,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
num_gqa_groups: Optional[int] = None
layernorm_type: str = 'layernorm'
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
hidden_dropout: float = 0.1
......@@ -1505,24 +1581,24 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout: float = 0.1
intermediate_dropout: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
dropout_rng_name: str = 'dropout'
dropout_rng_name: str = "dropout"
mha_kernel_init: Initializer = None
mlp_kernel_init: Initializer = None
mlp_activations: Sequence[str] = ('relu',)
mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = 'causal'
self_attn_mask_type: str = "causal"
self_attn_bias_type: Optional[str] = None
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
low_rank_adaptation_scope: str = 'none'
rotary_pos_emb_group_method: str = "consecutive"
low_rank_adaptation_scope: str = "none"
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
......@@ -1535,23 +1611,26 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
def __post_init__(self):
if self.mha_kernel_init is None:
self.mha_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
self.mha_kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
'truncated_normal')
self.mlp_kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal"
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
def __call__(self,
def __call__(
self,
inputs: Array,
encoded: Array = None,
attention_mask: Array = None,
encoder_decoder_mask: Array = None,
deterministic: bool = False,
decode: bool = False,
max_decode_length: bool = None):
max_decode_length: bool = None,
):
"""
Transformer Layer: attention block and a feedforward network (MLP)
......@@ -1585,17 +1664,18 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
outputs: jax.numpy.ndarray
Output tensors.
"""
assert self.layer_type in TransformerLayerType, \
"layer_type should be one of TransformerLayerType" \
f", but got {self.layer_type}."
assert (
self.layer_type in TransformerLayerType
), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."
assert self.hidden_size % self.num_attention_heads == 0, \
"hidden_size should be a multiple of num_attention_heads" \
assert self.hidden_size % self.num_attention_heads == 0, (
"hidden_size should be a multiple of num_attention_heads"
f", but got {self.hidden_size=} and {self.num_attention_heads=}."
)
assert self.layer_type == TransformerLayerType.DECODER or \
(self.layer_type == TransformerLayerType.ENCODER and decode is False), \
"decode should be False when layer_type == TransformerLayerType.ENCODER."
assert self.layer_type == TransformerLayerType.DECODER or (
self.layer_type == TransformerLayerType.ENCODER and decode is False
), "decode should be False when layer_type == TransformerLayerType.ENCODER."
head_dim = self.hidden_size // self.num_attention_heads
......@@ -1605,8 +1685,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
def generate_batch_seqlen_logical_axes(is_shared_seq=None):
axes = [None, None]
is_shared_seq = self.enable_sequence_parallel if is_shared_seq is None \
else is_shared_seq
is_shared_seq = (
self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq
)
axes[batch_dim] = BATCH_AXES
axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
......@@ -1615,13 +1696,14 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
attn_bias = None
if self.enable_relative_embedding:
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
rel_emb = RelativePositionBiases(
num_buckets=32,
max_distance=128,
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
name="relpos_bias",
)
else:
rel_emb = self.relative_embedding
......@@ -1639,12 +1721,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
# Make the name exactly the same as T5X, since names would affect
# the RNGKey during init and apply. Maybe not needed in the future.
if self.layer_type == TransformerLayerType.ENCODER:
mha_name = 'attention'
mha_name = "attention"
else:
mha_name = 'self_attention'
mha_name = "self_attention"
inputs = with_sharding_constraint_by_logical_axes(
inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
)
# [batch, length, emb_dim] -> [batch, length, emb_dim]
residual = inputs
......@@ -1677,12 +1760,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=self.mha_kernel_init,
use_bias=self.use_bias,
bias_init=self.bias_init,
name=mha_name)(inputs,
inputs,
attention_mask,
attn_bias,
deterministic=deterministic,
decode=decode)
name=mha_name,
)(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)
def hidden_dropout(x, deterministic):
assert isinstance(self.hidden_dropout_dims, Sequence)
......@@ -1690,21 +1769,27 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
for dims in self.hidden_dropout_dims:
assert -x_shape_len <= dims < x_shape_len
return nn.Dropout(rate=self.hidden_dropout,
return nn.Dropout(
rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims,
rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
rng_collection=self.dropout_rng_name,
)(x, deterministic=deterministic)
x = with_sharding_constraint_by_logical_axes(
x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
)
residual = with_sharding_constraint_by_logical_axes(
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
)
x = hidden_dropout(x, deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path,
x = nn.Dropout(
rate=self.drop_path,
broadcast_dims=drop_path_shape,
rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
rng_collection=self.dropout_rng_name,
)(x, deterministic=deterministic)
if self.apply_residual_connection_post_layernorm:
assert ln_out is not None
......@@ -1714,11 +1799,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mlp_input = x
if self.layer_type == TransformerLayerType.DECODER:
assert encoded is not None, \
"encoded is required when layer_type == TransformerLayerType.DECODER."
assert (
encoded is not None
), "encoded is required when layer_type == TransformerLayerType.DECODER."
x = with_sharding_constraint_by_logical_axes(
x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
)
residual = x
y, ln_out = MultiHeadAttention(
......@@ -1735,8 +1822,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
zero_centered_gamma=self.zero_centered_gamma,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
input_layernorm=True, # Must do LayerNorm before MHA.
attn_mask_type='padding',
attn_bias_type='no_bias',
attn_mask_type="padding",
attn_bias_type="no_bias",
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
......@@ -1750,15 +1837,15 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=self.mha_kernel_init,
use_bias=self.use_bias,
bias_init=self.bias_init,
name='encoder_decoder_attention')(x,
encoded,
encoder_decoder_mask,
deterministic=deterministic)
name="encoder_decoder_attention",
)(x, encoded, encoder_decoder_mask, deterministic=deterministic)
y = with_sharding_constraint_by_logical_axes(
y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
)
residual = with_sharding_constraint_by_logical_axes(
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
)
y = hidden_dropout(y, deterministic)
......@@ -1769,7 +1856,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mlp_input = y + residual
mlp_input = with_sharding_constraint_by_logical_axes(
mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
)
lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)
......@@ -1802,7 +1890,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
name='mlp',
name="mlp",
)(mlp_input, deterministic=deterministic)
if self.apply_residual_connection_post_layernorm:
......@@ -1810,27 +1898,33 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
residual = ln_out
z = with_sharding_constraint_by_logical_axes(
z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
)
residual = with_sharding_constraint_by_logical_axes(
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
)
z = hidden_dropout(z, deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
z = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape)(z, deterministic=deterministic)
z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
z, deterministic=deterministic
)
z = z + residual
if self.output_layernorm:
z = with_sharding_constraint_by_logical_axes(
z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
z = LayerNorm(layernorm_type=self.layernorm_type,
z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
)
z = LayerNorm(
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
scale_axes=(W_NO_SHARD_AXES,),
bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype,
name="output_layernorm")(z)
name="output_layernorm",
)(z)
return z
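A standalone sketch of the broadcast-dropout pattern used by hidden_dropout and drop_path above: broadcasting the mask over the non-batch axes turns ordinary dropout into a per-example (stochastic-depth style) drop.

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 16, 8))                              # (batch, seqlen, hidden)
drop = nn.Dropout(rate=0.1, broadcast_dims=(1, 2))    # one keep/drop decision per example
y = drop.apply({}, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(0)})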
......@@ -135,8 +135,8 @@ class FP8MetaPackage:
@staticmethod
def update_fp8_scale(
amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray],
fp8_dtype_list: List[DType]) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray], fp8_dtype_list: List[DType]
) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
"""
Get the updated scale and scale_inv lists
"""
......@@ -151,6 +151,7 @@ class FP8MetaPackage:
class AmaxComputeAlgo(Enum):
"""AmaxComputeAlgo."""
MAX = "max"
MOST_RECENT = "most_recent"
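A standalone sketch of what the two algorithms reduce over: given a rolling amax history whose index 0 holds the current step's value (as in the amax updates later in this diff), MAX takes the maximum over the window while MOST_RECENT takes only the newest entry.

import jax.numpy as jnp

amax_history = jnp.array([0.75, 2.0, 1.25, 0.5])   # hypothetical history, index 0 = current step

amax_max = jnp.max(amax_history)       # AmaxComputeAlgo.MAX         -> 2.0
amax_recent = amax_history[0]          # AmaxComputeAlgo.MOST_RECENT -> 0.75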
......@@ -162,6 +163,7 @@ class FP8Helper:
"""
FP8 helper to manage the FP8 meta
"""
INITIALIZED = False
MARGIN: float = 0.0
FP8_FORMAT: Format = Format.HYBRID
......@@ -184,18 +186,19 @@ class FP8Helper:
return FP8Helper.INITIALIZED
@staticmethod
def initialize(margin: float = 0.0,
def initialize(
margin: float = 0.0,
fp8_format: Format = Format.HYBRID,
amax_history_len: int = 1,
amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX) -> None:
amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX,
) -> None:
"""
Initialize the FP8 meta
"""
FP8Helper.INITIALIZED = True
FP8Helper.MARGIN = margin
FP8Helper.FP8_FORMAT = fp8_format
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
_format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = _format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.AMAX_HISTORY_LEN = amax_history_len
FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
FP8Helper.FP8_2X_ACC_FPROP = False
......@@ -210,8 +213,7 @@ class FP8Helper:
FP8Helper.INITIALIZED = False
FP8Helper.MARGIN = 0.0
FP8Helper.FP8_FORMAT = Format.HYBRID
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
_format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = _format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.AMAX_HISTORY_LEN = 1024
FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
......@@ -300,9 +302,11 @@ class FP8Helper:
@contextmanager
def fp8_autocast(enabled: bool = False,
def fp8_autocast(
enabled: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
mesh_resource: Optional[MeshResource] = None) -> None:
mesh_resource: Optional[MeshResource] = None,
) -> None:
r"""
Context manager for FP8 usage.
......@@ -344,13 +348,18 @@ def fp8_autocast(enabled: bool = False,
fp8_recipe = DelayedScaling()
assert fp8_recipe.amax_compute_algo in [
"max", "most_recent"
], ("DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX.")
assert fp8_recipe.scaling_factor_compute_algo is None, (
"DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX.")
assert fp8_recipe.override_linear_precision == (False, False, False), (
"DelayedScaling override_linear_precision isn't supported by TE/JAX.")
assert fp8_recipe.reduce_amax, ("DelayedScaling reduce_amax should be enabled for TE/JAX.")
"max",
"most_recent",
], "DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX."
assert (
fp8_recipe.scaling_factor_compute_algo is None
), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
assert fp8_recipe.override_linear_precision == (
False,
False,
False,
), "DelayedScaling override_linear_precision isn't supported by TE/JAX."
assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX."
if mesh_resource is None:
mesh_resource = MeshResource()
......@@ -362,13 +371,15 @@ def fp8_autocast(enabled: bool = False,
assert fp8_available, reason_for_no_fp8
amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
if fp8_recipe.amax_compute_algo == 'max':
if fp8_recipe.amax_compute_algo == "max":
amax_compute_algo = AmaxComputeAlgo.MAX
FP8Helper.initialize(margin=fp8_recipe.margin,
FP8Helper.initialize(
margin=fp8_recipe.margin,
fp8_format=fp8_recipe.fp8_format,
amax_history_len=fp8_recipe.amax_history_len,
amax_compute_algo=amax_compute_algo)
amax_compute_algo=amax_compute_algo,
)
yield
finally:
FP8Helper.finalize()
......@@ -410,9 +421,12 @@ def get_delayed_scaling():
delay_scaling : DelayedScaling
an instance of DelayedScaling which is set via fp8_autocast.
"""
amax_compute_algo = "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX \
else "most_recent"
return DelayedScaling(margin=int(FP8Helper.MARGIN),
amax_compute_algo = (
"max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent"
)
return DelayedScaling(
margin=int(FP8Helper.MARGIN),
fp8_format=FP8Helper.FP8_FORMAT,
amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
amax_compute_algo=amax_compute_algo)
amax_compute_algo=amax_compute_algo,
)
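A minimal usage sketch of the context manager above; the import paths and the DelayedScaling/Format recipe classes are assumed here rather than shown in this diff.

import transformer_engine.jax as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(
    margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max"
)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    # build and run the model here; FP8Helper stays initialized for the whole context
    ...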
......@@ -16,56 +16,55 @@ from .sharding import with_sharding_constraint_by_logical_axes
def canonicalize_layernorm_type(x):
'''
"""
Canonicalize the layernorm type
'''
canonicalized = x.lower().strip().replace('-', '').replace('_', '')
assert canonicalized in ['layernorm', 'rmsnorm']
"""
canonicalized = x.lower().strip().replace("-", "").replace("_", "")
assert canonicalized in ["layernorm", "rmsnorm"]
return canonicalized
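For example, assuming the function above is in scope, the canonicalization maps differently spelled aliases onto the two supported values:

assert canonicalize_layernorm_type("Layer_Norm") == "layernorm"
assert canonicalize_layernorm_type("RMS-Norm") == "rmsnorm"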
def layernorm(inputs: jnp.ndarray,
def layernorm(
inputs: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6):
epsilon: float = 1e-6,
):
"""
LN/RMSNorm wrapper
Only support layernorm_type in ['layernorm', 'rmsnorm']
"""
output = _layernorm(inputs,
output = _layernorm(
inputs,
gamma,
beta,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
epsilon=epsilon,
)
return output
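For reference, a standalone pure-JAX sketch of the semantics the two norms are expected to follow (ignoring zero_centered_gamma); the actual fused kernels are provided by the tex.* primitives used below.

import jax.numpy as jnp

def reference_layernorm(x, gamma, beta, epsilon=1e-6):
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + epsilon) * gamma + beta

def reference_rmsnorm(x, gamma, epsilon=1e-6):
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + epsilon)
    return x / rms * gamma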
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _layernorm(x,
gamma,
beta,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6):
def _layernorm(
x, gamma, beta, layernorm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6
):
output, _ = _layernorm_fwd_rule(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon)
return output
def _layernorm_fwd_rule(x,
gamma,
beta,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6):
def _layernorm_fwd_rule(
x, gamma, beta, layernorm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6
):
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'layernorm':
if layernorm_type == "layernorm":
output, mu, rsigma = tex.layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
elif layernorm_type == 'rmsnorm':
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
elif layernorm_type == "rmsnorm":
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
output, rsigma = tex.rmsnorm_fwd(x, gamma, epsilon)
mu = None
else:
......@@ -75,17 +74,14 @@ def _layernorm_fwd_rule(x,
def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
x, mu, rsigma, gamma = ctx
if layernorm_type == 'layernorm':
dx, dgamma, dbeta = tex.layernorm_bwd(dz,
x,
mu,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
elif layernorm_type == 'rmsnorm':
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
if layernorm_type == "layernorm":
dx, dgamma, dbeta = tex.layernorm_bwd(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
elif layernorm_type == "rmsnorm":
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
dx, dgamma = tex.rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
else:
......@@ -107,9 +103,11 @@ def layernorm_fp8_dot(
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
layernorm_input_axes: Tuple[
str, ...] = None, # The logic axes of sharding constraint to the layernorm input.
dot_input_axes: Tuple[str,
...] = None # The logic axes of sharding constraint to the dot input.
str, ...
] = None, # The logic axes of sharding constraint to the layernorm input.
dot_input_axes: Tuple[
str, ...
] = None, # The logic axes of sharding constraint to the dot input.
) -> jnp.ndarray:
"""
Layernorm + FP8 GEMM
......@@ -118,22 +116,55 @@ def layernorm_fp8_dot(
scale_list = fp8_meta_pkg.scale_list
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
output = _layernorm_fp8_dot(x, kernel, gamma, beta, amax_list, scale_list, layernorm_type,
fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon,
layernorm_input_axes, dot_input_axes)
output = _layernorm_fp8_dot(
x,
kernel,
gamma,
beta,
amax_list,
scale_list,
layernorm_type,
fwd_dtype,
bwd_dtype,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray],
layernorm_type: str, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
zero_centered_gamma: bool, epsilon: float,
layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...]):
output, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, amax_list, scale_list,
layernorm_type, fwd_dtype, bwd_dtype,
zero_centered_gamma, epsilon, layernorm_input_axes,
dot_input_axes)
def _layernorm_fp8_dot(
x: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
amax_list: List[jnp.ndarray],
scale_list: List[jnp.ndarray],
layernorm_type: str,
fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype,
zero_centered_gamma: bool,
epsilon: float,
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
):
output, _ = _layernorm_fp8_dot_fwd_rule(
x,
kernel,
gamma,
beta,
amax_list,
scale_list,
layernorm_type,
fwd_dtype,
bwd_dtype,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes,
)
return output
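The custom_vjp wiring above follows the standard JAX pattern; a minimal standalone example (hypothetical function) shows how nondiff_argnums are threaded: the bwd rule receives the non-differentiable arguments first, then the saved residuals, then the cotangent, and returns one gradient per differentiable input.

import jax
from functools import partial

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_square(x, scale):
    return scale * x * x

def scaled_square_fwd(x, scale):
    return scaled_square(x, scale), (x,)      # residuals saved for the backward pass

def scaled_square_bwd(scale, ctx, g):         # nondiff args first, then residuals, then cotangent
    (x,) = ctx
    return (2.0 * scale * x * g,)             # one entry per differentiable argument

scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)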
......@@ -150,20 +181,23 @@ def _layernorm_fp8_dot_fwd_rule(
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes):
dot_input_axes,
):
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list, *scale_list)
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair(
*amax_list, *scale_list
)
amax_list = maybe_fm32_to_fp32(*amax_list)
scale_list = maybe_fm32_to_fp32(*scale_list)
fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(amax_list, scale_list,
fp8_dtype_list)
scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(
amax_list, scale_list, fp8_dtype_list
)
amax_list = FP8MetaPackage.update_amax_list(amax_list)
x_amax = amax_list[FP8MetaPackage.INPUT_IDX][0:1]
......@@ -172,7 +206,7 @@ def _layernorm_fp8_dot_fwd_rule(
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm':
if layernorm_type == "layernorm":
ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8(
x,
gamma,
......@@ -182,17 +216,15 @@ def _layernorm_fp8_dot_fwd_rule(
x_scale_inv,
out_dtype=fwd_dtype,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
epsilon=epsilon,
)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(x,
gamma,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
epsilon=epsilon)
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(
x, gamma, x_amax, x_scale, x_scale_inv, out_dtype=fwd_dtype, epsilon=epsilon
)
mu = None
assert x.shape == ln_out.shape
......@@ -204,19 +236,41 @@ def _layernorm_fp8_dot_fwd_rule(
# Kernel in (hidden_in, hidden_out...)
# Note (Ming Huang): Use cast only to allow XLA to handle transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel, updated_kernel_amax = \
tex.cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype)
casted_kernel, updated_kernel_amax = tex.cast_fp8(
kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype
)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_input_axes)
# (batch..., hidden_in) x (hidden_in, hidden_out...)
output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
output = fp8_dot_impl(
ln_out,
casted_kernel,
x_scale_inv,
kernel_scale_inv,
x.dtype,
(x_contracting_dims, k_contracting_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
)
ctx = (ln_out, casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims,
k_contracting_dims, maybe_fp32_to_fm32)
ctx = (
ln_out,
casted_kernel,
amax_list,
scale_list,
scale_inv_list,
updated_x_amax,
updated_kernel_amax,
x.shape,
kernel.shape,
mu,
rsigma,
x,
gamma,
x_contracting_dims,
k_contracting_dims,
maybe_fp32_to_fm32,
)
return output, ctx
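A standalone sketch of the contraction layout used by the FP8 dot above ((batch..., hidden_in) x (hidden_in, hidden_out)), expressed with lax.dot_general's dimension_numbers; fp8_dot_impl is assumed to follow the same convention.

import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 3, 16))                  # (batch..., hidden_in)
kernel = jnp.ones((16, 32))               # (hidden_in, hidden_out)

x_contracting_dims = (x.ndim - 1,)        # contract the last axis of x ...
k_contracting_dims = (0,)                 # ... against the first axis of the kernel

y = lax.dot_general(x, kernel, ((x_contracting_dims, k_contracting_dims), ((), ())))
assert y.shape == (2, 3, 32)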
......@@ -230,11 +284,26 @@ def _layernorm_fp8_dot_bwd_rule(
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
ctx,
grad):
ln_out_, casted_kernel, amax_list, scale_list, scale_inv_list, \
updated_x_amax, updated_kernel_amax, \
x_shape, kernel_shape, mu, rsigma, x, gamma, \
x_contracting_dims, k_contracting_dims, maybe_fp32_to_fm32 = ctx
grad,
):
(
ln_out_,
casted_kernel,
amax_list,
scale_list,
scale_inv_list,
updated_x_amax,
updated_kernel_amax,
x_shape,
kernel_shape,
mu,
rsigma,
x,
gamma,
x_contracting_dims,
k_contracting_dims,
maybe_fp32_to_fm32,
) = ctx
ln_out_t = tex.transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1)
......@@ -242,53 +311,70 @@ def _layernorm_fp8_dot_bwd_rule(
grad_scale = scale_list[FP8MetaPackage.GRAD_IDX]
grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX]
casted_grad, casted_grad_t, updated_grad_amax = \
tex.cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=min(x_contracting_dims))
casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose(
grad,
grad_amax,
grad_scale,
grad_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=min(x_contracting_dims),
)
xt_constracting_dim = tuple(range(len(x_contracting_dims), len(x_shape)))
gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim))
x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
wgrad = fp8_dot_impl(ln_out_t, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
wgrad = fp8_dot_impl(
ln_out_t,
casted_grad_t,
x_scale_inv,
grad_scale_inv,
grad.dtype,
(xt_constracting_dim, gt_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
)
g_for_dgrad_constracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim))
range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim)
)
k_constracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape)))
kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
dgrad = fp8_dot_impl(
casted_grad,
casted_kernel,
grad_scale_inv,
kernel_scale_inv,
grad.dtype,
(g_for_dgrad_constracting_dim, k_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
if layernorm_type == 'layernorm':
dx, dgamma, dbeta = tex.layernorm_bwd(dgrad,
x,
mu,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
if layernorm_type == "layernorm":
dx, dgamma, dbeta = tex.layernorm_bwd(
dgrad, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
dx, dgamma = tex.rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
amax_list[FP8MetaPackage.INPUT_IDX] = \
amax_list[FP8MetaPackage.INPUT_IDX] = (
amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0])
amax_list[FP8MetaPackage.WEIGHT_IDX] = \
)
amax_list[FP8MetaPackage.WEIGHT_IDX] = (
amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0])
amax_list[FP8MetaPackage.GRAD_IDX] = \
)
amax_list[FP8MetaPackage.GRAD_IDX] = (
amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
)
amax_list = maybe_fp32_to_fm32(*amax_list)
scale_list = maybe_fp32_to_fm32(*scale_list)
return dx, wgrad, \
dgamma, dbeta, \
amax_list, scale_list
return dx, wgrad, dgamma, dbeta, amax_list, scale_list
_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd_rule, _layernorm_fp8_dot_bwd_rule)
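The amax bookkeeping above relies on JAX's functional index update; a tiny standalone sketch:

import jax.numpy as jnp

amax_history = jnp.zeros((4,), dtype=jnp.float32)     # rolling amax history for one tensor
updated_amax = jnp.float32(3.5)                       # amax observed in the current step

amax_history = amax_history.at[0].set(updated_amax)   # returns a new array; index 0 = current step
# amax_history -> [3.5, 0., 0., 0.]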
......@@ -41,7 +41,7 @@ def _activation_lu_fwd_rule(x, activation_type):
def _activation_lu_bwd_rule(activation_type, ctx, g):
x, = ctx
(x,) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
......@@ -52,7 +52,8 @@ def _activation_lu_bwd_rule(activation_type, ctx, g):
_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
def fused_layernorm_fp8_mlp(x: jnp.ndarray,
def fused_layernorm_fp8_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
kernels: List[jnp.ndarray],
......@@ -64,10 +65,11 @@ def fused_layernorm_fp8_mlp(x: jnp.ndarray,
layernorm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = 'ffn1',
ffn2_ckpt_name: str = 'ffn2',
activation_type: Sequence[Union[str, Callable]] = ('gelu',),
use_bias: bool = True) -> jnp.ndarray:
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
use_bias: bool = True,
) -> jnp.ndarray:
"""
Layernorm + GEMM1 + bias + activation + GEMM2 + bias
"""
......@@ -88,36 +90,91 @@ def fused_layernorm_fp8_mlp(x: jnp.ndarray,
bwd_dtype = FP8Helper.BWD_DTYPE
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
if layernorm_type == "rmsnorm":
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2,
amax_list_1, amax_list_2, scale_list_1, scale_list_2,
fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma,
epsilon, layernorm_input_axes, dot_1_input_axes,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
activation_type, use_bias)
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
output = _fused_layernorm_fp8_mlp(
x,
gamma,
beta,
kernel_1,
kernel_2,
bias_1,
bias_2,
amax_list_1,
amax_list_2,
scale_list_1,
scale_list_2,
fwd_dtype,
bwd_dtype,
layernorm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
use_bias,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
bias_2: jnp.ndarray, amax_list_1: List[jnp.ndarray],
amax_list_2: List[jnp.ndarray], scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray], fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
epsilon: float, layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]], use_bias: bool):
def _fused_layernorm_fp8_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray,
bias_1: jnp.ndarray,
bias_2: jnp.ndarray,
amax_list_1: List[jnp.ndarray],
amax_list_2: List[jnp.ndarray],
scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray],
fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype,
layernorm_type: str,
zero_centered_gamma: bool,
epsilon: float,
layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
use_bias: bool,
):
output, _ = _fused_layernorm_fp8_mlp_fwd_rule(
x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, amax_list_1, amax_list_2, scale_list_1,
scale_list_2, fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon,
layernorm_input_axes, dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
activation_type, use_bias)
x,
gamma,
beta,
kernel_1,
kernel_2,
bias_1,
bias_2,
amax_list_1,
amax_list_2,
scale_list_1,
scale_list_2,
fwd_dtype,
bwd_dtype,
layernorm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
use_bias,
)
return output
......@@ -144,7 +201,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
use_bias):
use_bias,
):
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out)
......@@ -159,20 +217,22 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
assert kernel_1.shape[-1] == kernel_2.shape[0]
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list_1, *scale_list_1,
*amax_list_2, *scale_list_2)
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair(
*amax_list_1, *scale_list_1, *amax_list_2, *scale_list_2
)
amax_list_1 = maybe_fm32_to_fp32(*amax_list_1)
scale_list_1 = maybe_fm32_to_fp32(*scale_list_1)
amax_list_2 = maybe_fm32_to_fp32(*amax_list_2)
scale_list_2 = maybe_fm32_to_fp32(*scale_list_2)
fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
scale_list_1, scale_inv_list_1 = FP8MetaPackage.update_fp8_scale(amax_list_1, scale_list_1,
fp8_dtype_list)
scale_list_1, scale_inv_list_1 = FP8MetaPackage.update_fp8_scale(
amax_list_1, scale_list_1, fp8_dtype_list
)
amax_list_1 = FP8MetaPackage.update_amax_list(amax_list_1)
scale_list_2, scale_inv_list_2 = FP8MetaPackage.update_fp8_scale(amax_list_2, scale_list_2,
fp8_dtype_list)
scale_list_2, scale_inv_list_2 = FP8MetaPackage.update_fp8_scale(
amax_list_2, scale_list_2, fp8_dtype_list
)
amax_list_2 = FP8MetaPackage.update_amax_list(amax_list_2)
x_amax = amax_list_1[FP8MetaPackage.INPUT_IDX][0:1]
......@@ -181,7 +241,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm':
if layernorm_type == "layernorm":
ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8(
x,
gamma,
......@@ -191,17 +251,15 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
x_scale_inv,
out_dtype=fwd_dtype,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
epsilon=epsilon,
)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(x,
gamma,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
epsilon=epsilon)
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(
x, gamma, x_amax, x_scale, x_scale_inv, out_dtype=fwd_dtype, epsilon=epsilon
)
mu = None
assert x.shape == ln_out.shape
......@@ -212,15 +270,22 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
# Note (Ming Huang): Use cast only and let XLA handle the transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel_1, updated_kernel_1_amax = \
tex.cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)
casted_kernel_1, updated_kernel_1_amax = tex.cast_fp8(
kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype
)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)
# (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
dot_1_output = fp8_dot_impl(
ln_out,
casted_kernel_1,
x_scale_inv,
kernel_1_scale_inv,
x.dtype,
(x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
)
if use_bias:
bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
......@@ -234,12 +299,18 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
activation_lu_out_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]
# (batch..., hidden_in) -> (batch..., hidden)
casted_activation_lu_out, updated_activation_lu_amax = \
tex.act_lu_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
activation_lu_out_scale_inv, fwd_dtype, activation_type)
casted_activation_lu_out, updated_activation_lu_amax = tex.act_lu_fp8(
dot_1_output,
activation_lu_out_amax,
activation_lu_out_scale,
activation_lu_out_scale_inv,
fwd_dtype,
activation_type,
)
casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
casted_activation_lu_out, dot_2_input_axes)
casted_activation_lu_out, dot_2_input_axes
)
kernel_2_scale = scale_list_2[FP8MetaPackage.WEIGHT_IDX]
kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
......@@ -248,10 +319,15 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)
# (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype,
dot_2_output = fp8_dot_impl(
casted_activation_lu_out,
casted_kernel_2,
activation_lu_out_scale_inv,
kernel_2_scale_inv,
x.dtype,
(x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
)
if use_bias:
bias_2_shape = bias_2.shape
......@@ -262,11 +338,32 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
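# Collect the residuals that _fused_layernorm_fp8_mlp_bwd_rule needs to compute gradients.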
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
casted_kernel_2, amax_list_1, amax_list_2, scale_list_1, scale_list_2, scale_inv_list_1,
scale_inv_list_2, updated_x_amax, updated_activation_lu_amax, updated_kernel_1_amax,
updated_kernel_2_amax, x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape,
maybe_fp32_to_fm32)
ctx = (
x,
ln_out,
mu,
rsigma,
gamma,
dot_1_output,
casted_activation_lu_out,
casted_kernel_1,
casted_kernel_2,
amax_list_1,
amax_list_2,
scale_list_1,
scale_list_2,
scale_inv_list_1,
scale_inv_list_2,
updated_x_amax,
updated_activation_lu_amax,
updated_kernel_1_amax,
updated_kernel_2_amax,
x_contracting_dims,
xt_batch_dims,
bias_1_shape,
bias_2_shape,
maybe_fp32_to_fm32,
)
return dot_2_output, ctx
......@@ -285,12 +382,34 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
activation_type,
use_bias,
ctx,
grad):
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \
casted_kernel_1, casted_kernel_2, amax_list_1, amax_list_2, scale_list_1, scale_list_2, \
scale_inv_list_1, scale_inv_list_2, updated_x_amax, \
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx
grad,
):
(
x,
ln_out,
mu,
rsigma,
gamma,
dot_1_output,
casted_activation_lu_out,
casted_kernel_1,
casted_kernel_2,
amax_list_1,
amax_list_2,
scale_list_1,
scale_list_2,
scale_inv_list_1,
scale_inv_list_2,
updated_x_amax,
updated_activation_lu_amax,
updated_kernel_1_amax,
updated_kernel_2_amax,
x_contracting_dims,
xt_batch_dims,
bias_1_shape,
bias_2_shape,
maybe_fp32_to_fm32,
) = ctx
grad_amax = amax_list_2[FP8MetaPackage.GRAD_IDX][0:1]
grad_scale = scale_list_2[FP8MetaPackage.GRAD_IDX]
......@@ -299,35 +418,55 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
# Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
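# Cast the incoming gradient to FP8 and build its transpose; the fused dbias kernel also reduces the second bias gradient when needed.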
if use_bias:
casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
tex.dbias_cast_transpose(grad, grad_amax, grad_scale,
grad_scale_inv, bwd_dtype,
casted_grad, casted_grad_t, dbias_2, updated_grad_amax = tex.dbias_cast_transpose(
grad,
grad_amax,
grad_scale,
grad_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
transpose_axis_boundary=-1,
)
dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
else:
casted_grad, casted_grad_t, updated_grad_amax = \
tex.cast_transpose(grad, grad_amax, grad_scale,
grad_scale_inv, bwd_dtype,
casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose(
grad,
grad_amax,
grad_scale,
grad_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
transpose_axis_boundary=-1,
)
dbias_2 = None
casted_activation_lu_out_t = tex.transpose(casted_activation_lu_out,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
casted_activation_lu_out_t = tex.transpose(
casted_activation_lu_out, static_axis_boundary=-1, transpose_axis_boundary=-1
)
# (hidden, batch...) x (hidden, batch...)
gemm2_x_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]
wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv,
grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
wgrad_2 = fp8_dot_impl(
casted_activation_lu_out_t,
casted_grad_t,
gemm2_x_scale_inv,
grad_scale_inv,
grad.dtype,
(xt_batch_dims, xt_batch_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
)
# (batch..., hidden_out) x (hidden_in, hidden_out)
kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
grad.dtype, (x_contracting_dims, (1,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_2 = fp8_dot_impl(
casted_grad,
casted_kernel_2,
grad_scale_inv,
kernel_2_scale_inv,
grad.dtype,
(x_contracting_dims, (1,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
)
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
......@@ -338,7 +477,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
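# Gated and non-gated activations take different fused backward paths, with or without a fused bias-gradient reduction.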
if len(activation_type) > 1: # if gated
if use_bias:
dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type)
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = (
tex.dbias_cast_transpose(
dactivation_lu,
dactivation_lu_amax,
......@@ -346,10 +485,12 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-2)
transpose_axis_boundary=-2,
)
)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else:
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = (
tex.dgated_act_lu_cast_transpose(
dgrad_2,
dot_1_output,
......@@ -358,11 +499,13 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
activation_type=activation_type)
activation_type=activation_type,
)
)
dbias_1 = None
else:
if use_bias:
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax=\
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = (
tex.dact_lu_dbias_cast_transpose(
dgrad_2,
dot_1_output,
......@@ -372,11 +515,13 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-2,
activation_type=activation_type)
activation_type=activation_type,
)
)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else:
dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type)
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = (
tex.cast_transpose(
dactivation_lu,
dactivation_lu_amax,
......@@ -384,7 +529,9 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-2)
transpose_axis_boundary=-2,
)
)
dbias_1 = None
ln_out_t = tex.transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
......@@ -392,54 +539,83 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
# (hidden, batch...) x (hidden, batch...)
gemm1_x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]
xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims)
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
(1, 2))
wgrad_1 = fp8_dot_impl(
ln_out_t,
casted_dactivation_lu_t,
gemm1_x_scale_inv,
dactivation_lu_scale_inv,
grad.dtype,
(xt_batch_dims, xt_batch_dims_2),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
)
x_contracting_dims = (
(min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
(1, 2),
)
kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]
dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv,
kernel_1_scale_inv, grad.dtype, x_contracting_dims,
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_1 = fp8_dot_impl(
casted_dactivation_lu,
casted_kernel_1,
dactivation_lu_scale_inv,
kernel_1_scale_inv,
grad.dtype,
x_contracting_dims,
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)
if layernorm_type == 'layernorm':
dx, dgamma, dbeta = tex.layernorm_bwd(dgrad_1,
x,
mu,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
if layernorm_type == "layernorm":
dx, dgamma, dbeta = tex.layernorm_bwd(
dgrad_1, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
dx, dgamma = tex.rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
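# Push the freshly measured amax values back into the FP8 metadata so subsequent scale updates see them.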
amax_list_1[FP8MetaPackage.INPUT_IDX] = \
amax_list_1[FP8MetaPackage.INPUT_IDX] = (
amax_list_1[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0])
amax_list_1[FP8MetaPackage.WEIGHT_IDX] = \
)
amax_list_1[FP8MetaPackage.WEIGHT_IDX] = (
amax_list_1[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_1_amax[0])
amax_list_1[FP8MetaPackage.GRAD_IDX] = \
)
amax_list_1[FP8MetaPackage.GRAD_IDX] = (
amax_list_1[FP8MetaPackage.GRAD_IDX].at[0].set(updated_dactivation_lu_amax[0])
amax_list_2[FP8MetaPackage.INPUT_IDX] = \
)
amax_list_2[FP8MetaPackage.INPUT_IDX] = (
amax_list_2[FP8MetaPackage.INPUT_IDX].at[0].set(updated_activation_lu_amax[0])
amax_list_2[FP8MetaPackage.WEIGHT_IDX] = \
)
amax_list_2[FP8MetaPackage.WEIGHT_IDX] = (
amax_list_2[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_2_amax)
amax_list_2[FP8MetaPackage.GRAD_IDX] = \
)
amax_list_2[FP8MetaPackage.GRAD_IDX] = (
amax_list_2[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
)
amax_list_1 = maybe_fp32_to_fm32(*amax_list_1)
scale_list_1 = maybe_fp32_to_fm32(*scale_list_1)
amax_list_2 = maybe_fp32_to_fm32(*amax_list_2)
scale_list_2 = maybe_fp32_to_fm32(*scale_list_2)
return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
amax_list_1, amax_list_2, scale_list_1, scale_list_2
return (
dx,
dgamma,
dbeta,
wgrad_1,
wgrad_2,
dbias_1,
dbias_2,
amax_list_1,
amax_list_2,
scale_list_1,
scale_list_2,
)
_fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule,
_fused_layernorm_fp8_mlp_bwd_rule)
_fused_layernorm_fp8_mlp.defvjp(
_fused_layernorm_fp8_mlp_fwd_rule, _fused_layernorm_fp8_mlp_bwd_rule
)
......@@ -49,17 +49,19 @@ class TransformerEngineBaseLayer(BaseLayer):
FP8Helper.FP8_COLLECTION_NAME: [
WeightHParamsCollection.SKIP_LP_REGULARIZATION,
WeightHParamsCollection.OVERWRITE_WITH_GRADIENT,
WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION,
]
}
flax_module_p = pax_fiddle.Config(flax_adapter.FlaxModuleAdapter,
flax_module_p = pax_fiddle.Config(
flax_adapter.FlaxModuleAdapter,
module_factory_method=flax_module_cls,
logical_axes_rules=self.logical_axes_rules,
var_collection_map=fp8_collection_map,
ici_mesh_shape=self.ici_mesh_shape,
dcn_mesh_shape=self.dcn_mesh_shape,
mesh_axis_names=self.mesh_axis_names)
mesh_axis_names=self.mesh_axis_names,
)
self.create_child(name, flax_module_p.clone())
......@@ -68,7 +70,7 @@ class LayerNorm(TransformerEngineBaseLayer):
"""LayerNorm"""
epsilon: float = 1e-6
layernorm_type: str = 'layernorm'
layernorm_type: str = "layernorm"
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
......@@ -80,17 +82,18 @@ class LayerNorm(TransformerEngineBaseLayer):
"""setup"""
super().setup()
ln_cls = partial(flax_LayerNorm,
ln_cls = partial(
flax_LayerNorm,
epsilon=self.epsilon,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
scale_init=_generate_ln_scale_init(self.scale_init),
scale_axes=self.scale_axes,
bias_init=TransformerEngineBaseLayer.generate_params_init(
"ln_bias", self.bias_init),
bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", self.bias_init),
bias_axes=self.bias_axes,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence)
transpose_batch_sequence=self.transpose_batch_sequence,
)
self.create_layer("layer_norm", ln_cls)
......@@ -109,9 +112,9 @@ class FusedSoftmax(TransformerEngineBaseLayer):
"""setup"""
super().setup()
fused_softmax_cls = partial(Softmax,
scale_factor=self.scale_factor,
softmax_type=self.softmax_type)
fused_softmax_cls = partial(
Softmax, scale_factor=self.scale_factor, softmax_type=self.softmax_type
)
self.create_layer("fused_softmax", fused_softmax_cls)
......@@ -151,7 +154,8 @@ class Linear(TransformerEngineBaseLayer):
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence)
transpose_batch_sequence=self.transpose_batch_sequence,
)
self.create_layer("linear", dense_general_cls)
......@@ -165,7 +169,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
out_features: int = 512
enable_layernorm: bool = True
layernorm_type: str = 'layernorm'
layernorm_type: str = "layernorm"
epsilon: float = 1e-6
zero_centered_gamma: bool = False
scale_init: WeightInit = None
......@@ -198,7 +202,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
scale_init=_generate_ln_scale_init(self.scale_init),
scale_axes=self.scale_axes,
ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
"ln_bias", self.ln_bias_init),
"ln_bias", self.ln_bias_init
),
ln_bias_axes=self.ln_bias_axes,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
kernel_axes=self.kernel_axes,
......@@ -212,7 +217,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
depth_scaling=self.depth_scaling)
depth_scaling=self.depth_scaling,
)
self.create_layer("ln_linear", ln_dense_general_cls)
......@@ -226,7 +232,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
intermediate_dim: int = 2048
enable_layernorm: bool = True
layernorm_type: str = 'layernorm'
layernorm_type: str = "layernorm"
epsilon: float = 1e-6
zero_centered_gamma: bool = False
scale_init: WeightInit = None
......@@ -243,7 +249,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ('relu',)
activations: Sequence[Union[str, Callable]] = ("relu",)
intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = ()
axis: Union[Iterable[int], int] = -1
......@@ -263,7 +269,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
scale_init=_generate_ln_scale_init(self.scale_init),
scale_axes=self.scale_axes,
ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
"ln_bias", self.ln_bias_init),
"ln_bias", self.ln_bias_init
),
ln_bias_axes=self.ln_bias_axes,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
kernel_axes_1=self.kernel_axes_1,
......@@ -281,7 +288,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence)
transpose_batch_sequence=self.transpose_batch_sequence,
)
self.create_layer("ln_mlp", ln_mlp_cls)
......
......@@ -35,7 +35,7 @@ class RelativePositionBiases(TransformerEngineBaseLayer):
"""generate_embedding_init"""
embedding_init = init
if embedding_init is None:
rb_stddev = (num_attention_heads * num_buckets)**-0.5
rb_stddev = (num_attention_heads * num_buckets) ** -0.5
embedding_init = WeightInit.Gaussian(rb_stddev)
return embedding_init
......@@ -44,16 +44,20 @@ class RelativePositionBiases(TransformerEngineBaseLayer):
super().setup()
embedding_init = RelativePositionBiases.generate_embedding_init(
self.embedding_init, self.num_attention_heads, self.num_buckets)
self.embedding_init, self.num_attention_heads, self.num_buckets
)
rpb_cls = partial(flax_RelativePositionBiases,
rpb_cls = partial(
flax_RelativePositionBiases,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
num_attention_heads=self.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init),
"rel_embedding", embedding_init
),
embedding_axes=self.embedding_axes,
dtype=self.dtype)
dtype=self.dtype,
)
self.create_layer("relative_position_bias", rpb_cls)
......@@ -68,12 +72,12 @@ class DotProductAttention(TransformerEngineBaseLayer):
head_dim: int = 0
num_attention_heads: int = 0
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.
attn_mask_type: AttnMaskType = 'causal'
attention_dropout: float = 0.0
attn_mask_type: AttnMaskType = "causal"
attn_bias_type: AttnBiasType = None
dropout_rng_name: str = 'dropout'
dropout_rng_name: str = "dropout"
float32_logits: bool = False
qkv_layout: str = 'bshd_bshd_bshd'
qkv_layout: str = "bshd_bshd_bshd"
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
......@@ -81,10 +85,11 @@ class DotProductAttention(TransformerEngineBaseLayer):
"""setup"""
super().setup()
assert self.head_dim > 0, f'{self.head_dim=}'
assert self.num_attention_heads > 0, f'{self.num_attention_heads=}'
assert self.head_dim > 0, f"{self.head_dim=}"
assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"
dpa_cls = partial(flax_DotProductAttention,
dpa_cls = partial(
flax_DotProductAttention,
head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
......@@ -96,25 +101,25 @@ class DotProductAttention(TransformerEngineBaseLayer):
float32_logits=self.float32_logits,
qkv_layout=self.qkv_layout,
scale_factor=self.scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence)
transpose_batch_sequence=self.transpose_batch_sequence,
)
self.create_layer("dot_product_attention", dpa_cls)
def __call__(self,
def __call__(
self,
query: JTensor,
key: JTensor,
value: JTensor,
mask: Optional[JTensor] = None,
bias: Optional[JTensor] = None,
*,
deterministic: bool = False) -> JTensor:
deterministic: bool = False,
) -> JTensor:
"""__call__"""
return self.dot_product_attention(query,
key,
value,
mask,
bias,
deterministic=deterministic)
return self.dot_product_attention(
query, key, value, mask, bias, deterministic=deterministic
)
class MultiHeadAttention(TransformerEngineBaseLayer):
......@@ -123,8 +128,8 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
head_dim: int = 0
num_attention_heads: int = 0
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.
dropout_rng_name: str = 'dropout'
attention_dropout: float = 0.0
dropout_rng_name: str = "dropout"
input_layernorm: bool = True
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
......@@ -132,12 +137,12 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
return_layernorm_output: bool = False
use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0)
attn_mask_type: str = 'causal'
attn_mask_type: str = "causal"
attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
low_rank_adaptation_scope: str = 'none'
rotary_pos_emb_group_method: str = "consecutive"
low_rank_adaptation_scope: str = "none"
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
fuse_qkv_params: bool = True
......@@ -160,24 +165,32 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
self.num_attention_heads = self.num_heads
warnings.warn(
f"{__class__}.num_heads is deprecated. It will be removed recently. "
f"Please uses {__class__}.num_attention_heads as the new API.", DeprecationWarning)
f"Please uses {__class__}.num_attention_heads as the new API.",
DeprecationWarning,
)
if self.dropout_rate is not None:
self.attention_dropout = self.dropout_rate
warnings.warn(
f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning)
f"Please use {__class__}.attention_dropout as the new API.",
DeprecationWarning,
)
if self.apply_residual_connection_post_layernorm is not None:
warnings.warn(
f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
f"It will be removed recently, please use {__class__}.return_layernorm_output.",
DeprecationWarning)
DeprecationWarning,
)
if self.fuse_qkv is not None:
warnings.warn(
f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning)
f"Please use {__class__}.fuse_qkv_params as the new API.",
DeprecationWarning,
)
assert self.output_layernorm is None, (
f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm.")
f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads
......@@ -187,8 +200,8 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
"""setup"""
super().setup()
assert self.head_dim > 0, f'{self.head_dim=}'
assert self.num_attention_heads > 0, f'{self.num_attention_heads=}'
assert self.head_dim > 0, f"{self.head_dim=}"
assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"
mha_cls = partial(
flax_MultiHeadAttention,
......@@ -219,25 +232,25 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
float32_logits=self.float32_logits)
float32_logits=self.float32_logits,
)
self.create_layer("multi_head_attn", mha_cls)
def __call__(self,
def __call__(
self,
inputs_q: JTensor,
inputs_kv: JTensor,
mask: Optional[JTensor] = None,
bias: Optional[JTensor] = None,
*,
decode: bool = False,
deterministic: bool = False) -> JTensor:
deterministic: bool = False,
) -> JTensor:
"""__call__"""
return self.multi_head_attn(inputs_q,
inputs_kv,
mask,
bias,
decode=decode,
deterministic=deterministic)
return self.multi_head_attn(
inputs_q, inputs_kv, mask, bias, decode=decode, deterministic=deterministic
)
class TransformerLayer(TransformerEngineBaseLayer):
......@@ -247,7 +260,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
num_gqa_groups: Optional[int] = None
layernorm_type: str = 'layernorm'
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
hidden_dropout: float = 0.1
......@@ -255,20 +268,20 @@ class TransformerLayer(TransformerEngineBaseLayer):
attention_dropout: float = 0.1
intermediate_dropout: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
dropout_rng_name: str = 'dropout'
mlp_activations: Sequence[str] = ('relu',)
dropout_rng_name: str = "dropout"
mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = 'causal'
self_attn_mask_type: str = "causal"
self_attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
low_rank_adaptation_scope: str = 'none'
rotary_pos_emb_group_method: str = "consecutive"
low_rank_adaptation_scope: str = "none"
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
enable_relative_embedding: bool = True
......@@ -291,23 +304,27 @@ class TransformerLayer(TransformerEngineBaseLayer):
relative_embedding_flax_module = None
if self.enable_relative_embedding and self.relative_embedding is not None:
assert self.relative_embedding.num_attention_heads == \
self.num_attention_heads, \
"TransformerLayer.relative_embedding.num_attention_heads shoule be" \
assert self.relative_embedding.num_attention_heads == self.num_attention_heads, (
"TransformerLayer.relative_embedding.num_attention_heads shoule be"
"the same as TransformerLayer.num_attention_heads."
)
embedding_init = RelativePositionBiases.generate_embedding_init(
self.relative_embedding.embedding_init, self.relative_embedding.num_attention_heads,
self.relative_embedding.num_buckets)
self.relative_embedding.embedding_init,
self.relative_embedding.num_attention_heads,
self.relative_embedding.num_buckets,
)
relative_embedding_flax_module = flax_RelativePositionBiases(
num_buckets=self.relative_embedding.num_buckets,
max_distance=self.relative_embedding.max_distance,
num_attention_heads=self.relative_embedding.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init),
"rel_embedding", embedding_init
),
embedding_axes=self.relative_embedding.embedding_axes,
dtype=self.relative_embedding.dtype)
dtype=self.relative_embedding.dtype,
)
transformerlayer_cls = partial(
flax_TransformerLayer,
......@@ -326,9 +343,11 @@ class TransformerLayer(TransformerEngineBaseLayer):
intermediate_dropout_dims=self.intermediate_dropout_dims,
dropout_rng_name=self.dropout_rng_name,
mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mha_kernel", self.params_init),
"mha_kernel", self.params_init
),
mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mlp_kernel", self.params_init),
"mlp_kernel", self.params_init
),
mlp_activations=self.mlp_activations,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
......@@ -351,18 +370,28 @@ class TransformerLayer(TransformerEngineBaseLayer):
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init)
scaled_query_init=self.scaled_query_init,
)
self.create_layer("transformerlayer", transformerlayer_cls)
def __call__(self,
def __call__(
self,
inputs: JTensor,
encoded: JTensor = None,
attention_mask: JTensor = None,
encoder_decoder_mask: JTensor = None,
deterministic: bool = False,
decode: bool = False,
max_decode_length: bool = None) -> JTensor:
max_decode_length: bool = None,
) -> JTensor:
"""__call__"""
return self.transformerlayer(inputs, encoded, attention_mask, encoder_decoder_mask,
deterministic, decode, max_decode_length)
return self.transformerlayer(
inputs,
encoded,
attention_mask,
encoder_decoder_mask,
deterministic,
decode,
max_decode_length,
)
......@@ -30,7 +30,7 @@ from build_tools.utils import package_files, copy_common_headers, install_and_im
from build_tools.te_version import te_version
from build_tools.jax import setup_jax_extension
install_and_import('pybind11')
install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension
CMakeBuildExtension = get_build_ext(BuildExtension)
......@@ -39,12 +39,12 @@ CMakeBuildExtension = get_build_ext(BuildExtension)
if __name__ == "__main__":
# Extensions
common_headers_dir = "common_headers"
copy_common_headers(
current_file_path.parent,
str(current_file_path / common_headers_dir))
copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir))
ext_modules = [
setup_jax_extension(
"csrc", current_file_path / "csrc", current_file_path / common_headers_dir)]
"csrc", current_file_path / "csrc", current_file_path / common_headers_dir
)
]
# Configure package
setuptools.setup(
......@@ -57,9 +57,11 @@ if __name__ == "__main__":
install_requires=["jax", "flax>=0.7.1"],
tests_require=["numpy", "praxis"],
include_package_data=True,
package_data={"csrc": package_files("csrc"),
package_data={
"csrc": package_files("csrc"),
common_headers_dir: package_files(common_headers_dir),
"build_tools": package_files("build_tools")},
"build_tools": package_files("build_tools"),
},
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
shutil.rmtree(common_headers_dir)
......@@ -17,23 +17,22 @@ from jax.sharding import PartitionSpec
_PXLA_THREAD_RESOURCES = pxla.thread_resources
# Axis Names
BATCH_AXES = 'nvte_batch'
SEQLEN_AXES = 'nvte_seqlen'
SEQLEN_TP_AXES = 'nvte_seqlen_tp'
HEAD_AXES = 'nvte_head'
HIDDEN_AXES = 'nvte_hidden'
HIDDEN_TP_AXES = 'nvte_hidden_tp'
JOINED_AXES = 'nvte_joined'
W_NO_SHARD_AXES = 'nvte_w_no_shard'
W_FSDP_AXES = 'nvte_w_fsdp'
W_TP_AXES = 'nvte_w_tp'
W_JOINED_AXES = 'nvte_w_joined'
BATCH_AXES = "nvte_batch"
SEQLEN_AXES = "nvte_seqlen"
SEQLEN_TP_AXES = "nvte_seqlen_tp"
HEAD_AXES = "nvte_head"
HIDDEN_AXES = "nvte_hidden"
HIDDEN_TP_AXES = "nvte_hidden_tp"
JOINED_AXES = "nvte_joined"
W_NO_SHARD_AXES = "nvte_w_no_shard"
W_FSDP_AXES = "nvte_w_fsdp"
W_TP_AXES = "nvte_w_tp"
W_JOINED_AXES = "nvte_w_joined"
def _get_mesh_info(resource: str):
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
assert resource in mesh.axis_names, \
f"{resource} is not in the axis_names of Mesh {mesh}."
assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}."
return mesh.shape[resource], resource
......@@ -45,8 +44,11 @@ def get_sharding_map_logic_axis_to_mesh_axis():
IS_FSDP_OUTER = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", False)))
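# NVTE_OUTER_BATCH_FSDP_DIM controls whether the FSDP axis precedes the DP axis in the batch sharding rule.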
batch_resources = [gsr.fsdp_resource, gsr.dp_resource] if IS_FSDP_OUTER \
batch_resources = (
[gsr.fsdp_resource, gsr.dp_resource]
if IS_FSDP_OUTER
else [gsr.dp_resource, gsr.fsdp_resource]
)
batch_dim_rule = []
for resource in batch_resources:
......@@ -168,6 +170,7 @@ class MeshResource:
The axis name in Mesh used to split model layers along.
If it is None, then pipeline parallelism is disabled.
"""
dp_resource: str = None
tp_resource: str = None
fsdp_resource: str = None
......@@ -240,6 +243,7 @@ class MajorShardingType(Enum):
DPTP:
Data and Standard tensor parallel training.
"""
SINGLE = 0
DP = 1
TP = 2
......@@ -267,6 +271,7 @@ class ShardingType(Enum):
DP_TP_ROW:
Sharding along data and row-split tensor parallelism.
"""
SINGLE = (MajorShardingType.SINGLE, "single")
DP = (MajorShardingType.DP, "dp")
TP_COL = (MajorShardingType.TP, "tp_col")
......