Unverified Commit eed4dfc6 authored by Ming-Xu Huang and committed by GitHub

[JAX] Support FP8 training for Pipeline Parallelism when Micro-batch > 1 on Paxml. (#774)



* Support FP8 Meta Dtype (FM32) and Align FP8 Scale Update with PyTorch.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Address code-review feedback.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Hide FlaxFloatMeta32 inside fp8.py.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Make functions JAX-traceable objects.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Rebase onto main.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Update JAX images for the GitHub workflow.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
parent 7acb5e2b
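The core mechanism of this change: FP8 meta tensors (fp8_max, amax, scale, scale_inv) travel through the scanned layer in the custom fm32 dtype and are converted to fp32 only inside the FP8 ops, then converted back before the meta gradients leave the backward rule. A minimal sketch of that round trip, assuming flax.linen.fp8_ops.fm32 is available; the helper names below are illustrative, not the ones this commit adds:

import jax
import jax.numpy as jnp
from flax.linen import fp8_ops  # provides the custom fm32 dtype

FM32 = fp8_ops.fm32  # the dtype this commit exposes as FlaxFloatMeta32

def metas_to_fp32(*metas):
    # fm32 -> fp32 so ordinary arithmetic can be applied to the FP8 meta tensors.
    return [jax.lax.convert_element_type(m, jnp.float32) for m in metas]

def metas_to_fm32(*metas):
    # fp32 -> fm32 before returning meta "gradients", so an outer nn.scan over
    # micro-batches accumulates them with MAX instead of SUM.
    return [jax.lax.convert_element_type(m, FM32) for m in metas]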
......@@ -31,7 +31,7 @@ jobs:
name: 'JAX'
runs-on: ubuntu-latest
container:
image: ghcr.io/nvidia/jax:latest
image: ghcr.io/nvidia/jax:jax
options: --user root
steps:
- name: 'Checkout'
......
......@@ -103,12 +103,18 @@ def _fp8_dot_fwd_rule(
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims):
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
x_shape_suf = x.shape[min(lhs_contracting_dims):]
kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
assert x_shape_suf == kernel_shape_pre
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
amax = FP8Helper.update_amax_history(amax)
gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
......@@ -130,7 +136,7 @@ def _fp8_dot_fwd_rule(
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
ctx = (casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape)
updated_kernel_amax, x.shape, kernel.shape, maybe_fp32_to_fm32)
return output, ctx
......@@ -138,7 +144,8 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # p
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, \
updated_x_amax, updated_kernel_amax, x_shape, kernel_shape = ctx
updated_x_amax, updated_kernel_amax, x_shape, kernel_shape, \
maybe_fp32_to_fm32 = ctx
gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
......@@ -170,7 +177,7 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # p
amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax)
amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)
return dgrad, wgrad, fp8_max, amax, scale, scale_inv
......
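Context for the reordered scale update in the forward rule above: the scale is now refreshed from the amax history before quantization, matching PyTorch's delayed-scaling order of operations. A rough sketch of such a recipe, under the assumption of the usual margin-based formula (not copied from this repository):

import jax.numpy as jnp

def sketch_update_fp8_scale(fp8_max, amax_history, scale, margin=0.0):
    # Assumed delayed-scaling recipe: derive the new scale from the largest
    # amax in the history window; keep the old scale when amax is zero or the
    # result is non-finite.
    amax = jnp.max(amax_history, axis=-1, keepdims=True)
    new_scale = (fp8_max / amax) / (2.0 ** margin)
    new_scale = jnp.where(jnp.isfinite(new_scale) & (amax > 0.0), new_scale, scale)
    return new_scale, 1.0 / new_scale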
......@@ -11,6 +11,7 @@ from typing import Dict, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen import fp8_ops
from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
......@@ -67,6 +68,15 @@ def _format2dtypes(format_: Format):
return jnp.bfloat16, jnp.bfloat16
# fm32 is a custom dtype that defines the "add" rule as a max operation.
# It is typically used with Pipeline Parallelism + "Micro-batching > 1",
# which is implemented via nn.scan. Without this custom dtype, nn.scan
# would sum the gradients from all micro-batches, which is not the expected
# behavior for FP8 meta. Instead, FP8 meta gradients should be reduced
# with "MAX".
FlaxFloatMeta32 = fp8_ops.fm32
class FP8MetaPackage:
"""
A container that contains all required meta data for FP8
......@@ -303,6 +313,42 @@ class FP8Helper:
return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)
@staticmethod
def generate_fp8_meta_dtype_converter_pair(*args):
"""
Generate a pair of conversion functions between fm32 and fp32.
"""
def identical_fun(*metas):
return metas
def fm32_to_fp32_fun(*metas):
for meta in metas:
assert meta.dtype == FlaxFloatMeta32
return [jax.lax.convert_element_type(meta, jnp.float32) for meta in metas]
def fp32_to_fm32_fun(*metas):
for meta in metas:
assert meta.dtype == jnp.float32
return [jax.lax.convert_element_type(meta, FlaxFloatMeta32) for meta in metas]
# Make the functions valid JAX types
partial_identical_fun = jax.tree_util.Partial(identical_fun)
partial_fm32_to_fp32_fun = jax.tree_util.Partial(fm32_to_fp32_fun)
partial_fp32_to_fm32_fun = jax.tree_util.Partial(fp32_to_fm32_fun)
if len(args) < 1:
return partial_identical_fun, partial_identical_fun
original_dtype = args[0].dtype
for arg in args:
assert arg.dtype == original_dtype
if original_dtype == FlaxFloatMeta32:
return partial_fm32_to_fp32_fun, partial_fp32_to_fm32_fun
return partial_identical_fun, partial_identical_fun
@staticmethod
def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
"""
......
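On the "valid JAX type" note above: a bare Python function is not a pytree, so it cannot be carried in the residuals of a custom VJP or returned from a forward rule. Wrapping it in jax.tree_util.Partial registers it as a pytree with no array leaves, which is how maybe_fp32_to_fm32 is threaded through ctx. A small self-contained illustration of the pattern; the function names here are made up for the example:

import jax
import jax.numpy as jnp

def double(x):
    return 2.0 * x

# Partial turns the callable into a registered pytree, so the custom-VJP
# machinery can carry it inside the forward rule's residuals.
wrapped_double = jax.tree_util.Partial(double)

@jax.custom_vjp
def f(x):
    return double(x)

def f_fwd(x):
    return double(x), (wrapped_double,)   # stash the callable in the residuals

def f_bwd(res, g):
    fn, = res
    return (fn(g),)                        # d(2x)/dx = 2, so the cotangent is 2*g

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(jnp.float32(3.0)))       # 2.0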
......@@ -162,6 +162,11 @@ def _layernorm_fp8_dot_fwd_rule(
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
amax = FP8Helper.update_amax_history(amax)
gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
......@@ -216,7 +221,7 @@ def _layernorm_fp8_dot_fwd_rule(
ctx = (ln_out, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims,
k_contracting_dims)
k_contracting_dims, maybe_fp32_to_fm32)
return output, ctx
......@@ -234,7 +239,7 @@ def _layernorm_fp8_dot_bwd_rule(
ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \
updated_x_amax, updated_kernel_amax, \
x_shape, kernel_shape, mu, rsigma, x, gamma, \
x_contracting_dims, k_contracting_dims = ctx
x_contracting_dims, k_contracting_dims, maybe_fp32_to_fm32 = ctx
ln_out_t = transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1)
......@@ -282,7 +287,7 @@ def _layernorm_fp8_dot_bwd_rule(
amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0])
amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)
return dx, wgrad, \
dgamma, dbeta, \
......
......@@ -26,27 +26,42 @@ from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
activation_dict = {
('gelu',): {'fwd': gelu,
"bwd": dgelu},
('gelu', 'linear'): {'fwd': gated_gelu,
'bwd': dgated_gelu},
('silu',): {'fwd': silu,
"bwd": dsilu },
('silu', 'linear'): {'fwd': gated_silu,
'bwd': dgated_silu}
('gelu',): {
'fwd': gelu,
"bwd": dgelu
},
('gelu', 'linear'): {
'fwd': gated_gelu,
'bwd': dgated_gelu
},
('silu',): {
'fwd': silu,
"bwd": dsilu
},
('silu', 'linear'): {
'fwd': gated_silu,
'bwd': dgated_silu
}
}
activation_fp8_dict = {
('gelu',): {'fwd': gelu_fp8,
'bwd': dgelu_dbias_cast_transpose},
('gelu', 'linear'): {'fwd': gated_gelu_fp8,
'bwd': dgated_gelu_cast_transpose},
('silu',): { 'fwd': silu_fp8,
'bwd': dsilu_dbias_cast_transpose },
('silu', 'linear'): { 'fwd': gated_silu_fp8,
'bwd': dgated_silu_cast_transpose }
('gelu',): {
'fwd': gelu_fp8,
'bwd': dgelu_dbias_cast_transpose
},
('gelu', 'linear'): {
'fwd': gated_gelu_fp8,
'bwd': dgated_gelu_cast_transpose
},
('silu',): {
'fwd': silu_fp8,
'bwd': dsilu_dbias_cast_transpose
},
('silu', 'linear'): {
'fwd': gated_silu_fp8,
'bwd': dgated_silu_cast_transpose
}
}
......@@ -59,6 +74,7 @@ def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]
output = _activation_lu(x, activation_type)
return output
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
......@@ -66,10 +82,12 @@ def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable
return _output
def _activation_lu_fwd_rule(x, activation_type):
fwd_output = activation_dict[activation_type]["fwd"](x)
return fwd_output, (x,)
def _activation_lu_bwd_rule(activation_type, ctx, g):
x, = ctx
assert x.dtype == g.dtype
......@@ -78,11 +96,12 @@ def _activation_lu_bwd_rule(activation_type, ctx, g):
dx = jnp.reshape(dx, x.shape)
return (dx,)
_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, fwd_dtype:jnp.dtype, bwd_dtype: jnp.dtype,
def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
activation_type: Sequence[Union[str, Callable]]):
"""
Activation Unit
......@@ -91,39 +110,47 @@ def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype)
dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
output = _activation_lu_fp8(x, dx_trans_no_use, dbias_no_use, amax,
scale, scale_inv, fwd_dtype, bwd_dtype, activation_type)
output = _activation_lu_fp8(x, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv, fwd_dtype,
bwd_dtype, activation_type)
return output
@partial(jax.custom_vjp, nondiff_argnums=(6,7,8))
def _activation_lu_fp8(x: jnp.ndarray,
dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray,
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
def _activation_lu_fp8(x: jnp.ndarray, dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray,
amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
activation_type: Sequence[Union[str, Callable]]):
output = _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax,
scale, scale_inv, fwd_dtype, bwd_dtype,
activation_type)
output = _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv,
fwd_dtype, bwd_dtype, activation_type)
return output
def _activation_lu_fp8_fwd_rule(x,
def _activation_lu_fp8_fwd_rule(
x,
dx_trans_no_use, # pylint: disable=unused-argument
dbias_no_use, # pylint: disable=unused-argument
amax,
scale, scale_inv,
fwd_dtype, bwd_dtype, # pylint: disable=unused-argument
scale,
scale_inv,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
activation_type):
activation_lu_out, _ = activation_fp8_dict[activation_type ]["fwd"](
x, amax, scale, scale_inv, fwd_dtype)
activation_lu_out, _ = activation_fp8_dict[activation_type]["fwd"](x, amax, scale, scale_inv,
fwd_dtype)
activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
ctx = (x, amax, scale, scale_inv)
return activation_lu_out, ctx
def _activation_lu_fp8_bwd_rule(fwd_dtype, bwd_dtype, # pylint: disable=unused-argument
activation_type, ctx, g):
def _activation_lu_fp8_bwd_rule(
fwd_dtype, # pylint: disable=unused-argument
bwd_dtype,
activation_type,
ctx,
g):
x, amax, scale, scale_inv = ctx
activation_lu_fp8_bwd = activation_fp8_dict[activation_type]["bwd"]
......@@ -139,6 +166,7 @@ def _activation_lu_fp8_bwd_rule(fwd_dtype, bwd_dtype, # pylint: disable=unused
ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv)
return ctx
_activation_lu_fp8.defvjp(_activation_lu_fp8_fwd_rule, _activation_lu_fp8_bwd_rule)
......@@ -200,15 +228,12 @@ def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarr
epsilon: float, layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
use_bias: bool):
output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1,
bias_2, fp8_max, amax, scale, scale_inv,
fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon,
layernorm_input_axes, dot_1_input_axes,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
activation_type, use_bias)
activation_type: Sequence[Union[str, Callable]], use_bias: bool):
output, _ = _fused_layernorm_fp8_mlp_fwd_rule(
x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, amax, scale, scale_inv,
fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon, layernorm_input_axes,
dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type,
use_bias)
return output
......@@ -256,6 +281,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
if not is_gated:
kernel_1 = jnp.squeeze(kernel_1, axis=-2)
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
amax = FP8Helper.update_amax_history(amax)
gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
......@@ -324,8 +354,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
activation_lu_fwd_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
activation_lu_out_scale_inv, fwd_dtype)
casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out,
dot_2_input_axes)
casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
casted_activation_lu_out, dot_2_input_axes)
kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
......@@ -335,8 +365,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
# (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
activation_lu_out_scale_inv,
kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype,
(x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
if use_bias:
......@@ -348,7 +378,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape)
x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape, maybe_fp32_to_fm32)
return dot_2_output, ctx
......@@ -371,7 +401,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \
casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape= ctx
x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx
is_gated = len(activation_type) > 1
......@@ -481,8 +511,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
xt_batch_dims_2 = xt_batch_dims if not is_gated \
else tuple(i + 1 for i in xt_batch_dims)
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
dactivation_lu_scale_inv, grad.dtype,
(xt_batch_dims, xt_batch_dims_2),
dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# Expand act axis to match the shape with the given kernel_1
if not is_gated:
......@@ -490,14 +519,13 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
# (batch..., hidden_out) x (hidden_in, hidden_out)
if is_gated:
x_contracting_dims = ((min(x_contracting_dims),) + tuple(
i + 1 for i in x_contracting_dims), (1,2))
x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
(1, 2))
else:
x_contracting_dims = (x_contracting_dims, (1,))
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1,
dactivation_lu_scale_inv, kernel_1_scale_inv,
grad.dtype, x_contracting_dims,
dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv,
kernel_1_scale_inv, grad.dtype, x_contracting_dims,
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)
......@@ -523,7 +551,8 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)
return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
fp8_max, amax, scale, scale_inv
......