"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "70d3251fdd8d8e10d754ff0d9de67527e8ea3bf8"
Unverified Commit eed4dfc6 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

[JAX] Support FP8 training for Pipeline Parallelism when Micro-batch > 1 on Paxml. (#774)



* Support FP8 Meta Dtype (FM32) and Align FP8 Scale Update with PyTorch.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Modify with the feedback of code review
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Hiding FlaxFloatMeta32 inside fp8.py
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Make functions to be JAX tracable objects.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Rebased with mian.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Update jax images for github workflow.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
parent 7acb5e2b
...@@ -31,7 +31,7 @@ jobs: ...@@ -31,7 +31,7 @@ jobs:
name: 'JAX' name: 'JAX'
runs-on: ubuntu-latest runs-on: ubuntu-latest
container: container:
image: ghcr.io/nvidia/jax:latest image: ghcr.io/nvidia/jax:jax
options: --user root options: --user root
steps: steps:
- name: 'Checkout' - name: 'Checkout'
......
...@@ -103,12 +103,18 @@ def _fp8_dot_fwd_rule( ...@@ -103,12 +103,18 @@ def _fp8_dot_fwd_rule(
fwd_dtype, fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument bwd_dtype, # pylint: disable=unused-argument
contracting_dims): contracting_dims):
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
lhs_contracting_dims, rhs_contracting_dims = contracting_dims lhs_contracting_dims, rhs_contracting_dims = contracting_dims
x_shape_suf = x.shape[min(lhs_contracting_dims):] x_shape_suf = x.shape[min(lhs_contracting_dims):]
kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1] kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
assert x_shape_suf == kernel_shape_pre assert x_shape_suf == kernel_shape_pre
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
amax = FP8Helper.update_amax_history(amax) amax = FP8Helper.update_amax_history(amax)
gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
...@@ -130,7 +136,7 @@ def _fp8_dot_fwd_rule( ...@@ -130,7 +136,7 @@ def _fp8_dot_fwd_rule(
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
ctx = (casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax, ctx = (casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape) updated_kernel_amax, x.shape, kernel.shape, maybe_fp32_to_fm32)
return output, ctx return output, ctx
...@@ -138,7 +144,8 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # p ...@@ -138,7 +144,8 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # p
lhs_contracting_dims, rhs_contracting_dims = contracting_dims lhs_contracting_dims, rhs_contracting_dims = contracting_dims
casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, \ casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, \
updated_x_amax, updated_kernel_amax, x_shape, kernel_shape = ctx updated_x_amax, updated_kernel_amax, x_shape, kernel_shape, \
maybe_fp32_to_fm32 = ctx
gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0) gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
...@@ -170,7 +177,7 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # p ...@@ -170,7 +177,7 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # p
amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax) amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax)
amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0]) amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)
return dgrad, wgrad, fp8_max, amax, scale, scale_inv return dgrad, wgrad, fp8_max, amax, scale, scale_inv
......
...@@ -11,6 +11,7 @@ from typing import Dict, Optional, Tuple, Union ...@@ -11,6 +11,7 @@ from typing import Dict, Optional, Tuple, Union
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen import fp8_ops
from transformer_engine_jax import DType from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version from transformer_engine_jax import get_cublasLt_version
...@@ -67,6 +68,15 @@ def _format2dtypes(format_: Format): ...@@ -67,6 +68,15 @@ def _format2dtypes(format_: Format):
return jnp.bfloat16, jnp.bfloat16 return jnp.bfloat16, jnp.bfloat16
# fm32 is a custom dtype to specify the "add" rules as max operation.
# This is typically used in Pipeline Parallelism + "MiconBatching > 1",
# which is implemented via nn.scan. Without this custom dtype, nn.scan
# would sum gradients from all micro-batches, and this is not the expected
# behavior for FP8 meta. Instead, the summation of FP8 meta gradients should
# be "MAX".
FlaxFloatMeta32 = fp8_ops.fm32
class FP8MetaPackage: class FP8MetaPackage:
""" """
A container that contains all required meta data for FP8 A container that contains all required meta data for FP8
...@@ -303,6 +313,42 @@ class FP8Helper: ...@@ -303,6 +313,42 @@ class FP8Helper:
return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays) return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)
@staticmethod
def generate_fp8_meta_dtype_converter_pair(*args):
"""
Generate a pair of conversion fun in-between fm32 and fp32.
"""
def identical_fun(*metas):
return metas
def fm32_to_fp32_fun(*metas):
for meta in metas:
assert meta.dtype == FlaxFloatMeta32
return [jax.lax.convert_element_type(meta, jnp.float32) for meta in metas]
def fp32_to_fm32_fun(*metas):
for meta in metas:
assert meta.dtype == jnp.float32
return [jax.lax.convert_element_type(meta, FlaxFloatMeta32) for meta in metas]
# Make functions to be a vaild JAX type
partial_identical_fun = jax.tree_util.Partial(identical_fun)
partial_fm32_to_fp32_fun = jax.tree_util.Partial(fm32_to_fp32_fun)
partial_fp32_to_fm32_fun = jax.tree_util.Partial(fp32_to_fm32_fun)
if len(args) < 1:
return partial_identical_fun, partial_identical_fun
original_dtype = args[0].dtype
for arg in args:
assert arg.dtype == original_dtype
if original_dtype == FlaxFloatMeta32:
return partial_fm32_to_fp32_fun, partial_fp32_to_fm32_fun
return partial_identical_fun, partial_identical_fun
@staticmethod @staticmethod
def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray: def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
""" """
......
...@@ -162,6 +162,11 @@ def _layernorm_fp8_dot_fwd_rule( ...@@ -162,6 +162,11 @@ def _layernorm_fp8_dot_fwd_rule(
k_contracting_dims = (0,) k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0] assert x.shape[-1] == kernel.shape[0]
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
amax = FP8Helper.update_amax_history(amax) amax = FP8Helper.update_amax_history(amax)
gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
...@@ -216,7 +221,7 @@ def _layernorm_fp8_dot_fwd_rule( ...@@ -216,7 +221,7 @@ def _layernorm_fp8_dot_fwd_rule(
ctx = (ln_out, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax, ctx = (ln_out, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims, updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims,
k_contracting_dims) k_contracting_dims, maybe_fp32_to_fm32)
return output, ctx return output, ctx
...@@ -234,7 +239,7 @@ def _layernorm_fp8_dot_bwd_rule( ...@@ -234,7 +239,7 @@ def _layernorm_fp8_dot_bwd_rule(
ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \ ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \
updated_x_amax, updated_kernel_amax, \ updated_x_amax, updated_kernel_amax, \
x_shape, kernel_shape, mu, rsigma, x, gamma, \ x_shape, kernel_shape, mu, rsigma, x, gamma, \
x_contracting_dims, k_contracting_dims = ctx x_contracting_dims, k_contracting_dims, maybe_fp32_to_fm32 = ctx
ln_out_t = transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1) ln_out_t = transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1)
...@@ -282,7 +287,7 @@ def _layernorm_fp8_dot_bwd_rule( ...@@ -282,7 +287,7 @@ def _layernorm_fp8_dot_bwd_rule(
amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0]) amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0])
amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0]) amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)
return dx, wgrad, \ return dx, wgrad, \
dgamma, dbeta, \ dgamma, dbeta, \
......
...@@ -26,27 +26,42 @@ from .layernorm import canonicalize_layernorm_type ...@@ -26,27 +26,42 @@ from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes from .sharding import with_sharding_constraint_by_logical_axes
activation_dict = { activation_dict = {
('gelu',): {'fwd': gelu, ('gelu',): {
"bwd": dgelu}, 'fwd': gelu,
('gelu', 'linear'): {'fwd': gated_gelu, "bwd": dgelu
'bwd': dgated_gelu}, },
('silu',): {'fwd': silu, ('gelu', 'linear'): {
"bwd": dsilu }, 'fwd': gated_gelu,
('silu', 'linear'): {'fwd': gated_silu, 'bwd': dgated_gelu
'bwd': dgated_silu} },
('silu',): {
'fwd': silu,
"bwd": dsilu
},
('silu', 'linear'): {
'fwd': gated_silu,
'bwd': dgated_silu
}
} }
activation_fp8_dict = { activation_fp8_dict = {
('gelu',): {'fwd': gelu_fp8, ('gelu',): {
'bwd': dgelu_dbias_cast_transpose}, 'fwd': gelu_fp8,
('gelu', 'linear'): {'fwd': gated_gelu_fp8, 'bwd': dgelu_dbias_cast_transpose
'bwd': dgated_gelu_cast_transpose}, },
('silu',): { 'fwd': silu_fp8, ('gelu', 'linear'): {
'bwd': dsilu_dbias_cast_transpose }, 'fwd': gated_gelu_fp8,
('silu', 'linear'): { 'fwd': gated_silu_fp8, 'bwd': dgated_gelu_cast_transpose
'bwd': dgated_silu_cast_transpose } },
('silu',): {
'fwd': silu_fp8,
'bwd': dsilu_dbias_cast_transpose
},
('silu', 'linear'): {
'fwd': gated_silu_fp8,
'bwd': dgated_silu_cast_transpose
}
} }
...@@ -59,6 +74,7 @@ def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable] ...@@ -59,6 +74,7 @@ def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]
output = _activation_lu(x, activation_type) output = _activation_lu(x, activation_type)
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(1,)) @partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]): def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
...@@ -66,10 +82,12 @@ def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable ...@@ -66,10 +82,12 @@ def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable
return _output return _output
def _activation_lu_fwd_rule(x, activation_type): def _activation_lu_fwd_rule(x, activation_type):
fwd_output = activation_dict[activation_type]["fwd"](x) fwd_output = activation_dict[activation_type]["fwd"](x)
return fwd_output, (x,) return fwd_output, (x,)
def _activation_lu_bwd_rule(activation_type, ctx, g): def _activation_lu_bwd_rule(activation_type, ctx, g):
x, = ctx x, = ctx
assert x.dtype == g.dtype assert x.dtype == g.dtype
...@@ -78,11 +96,12 @@ def _activation_lu_bwd_rule(activation_type, ctx, g): ...@@ -78,11 +96,12 @@ def _activation_lu_bwd_rule(activation_type, ctx, g):
dx = jnp.reshape(dx, x.shape) dx = jnp.reshape(dx, x.shape)
return (dx,) return (dx,)
_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule) _activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
scale_inv: jnp.ndarray, fwd_dtype:jnp.dtype, bwd_dtype: jnp.dtype, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
activation_type: Sequence[Union[str, Callable]]): activation_type: Sequence[Union[str, Callable]]):
""" """
Activation Unit Activation Unit
...@@ -91,39 +110,47 @@ def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, ...@@ -91,39 +110,47 @@ def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype) dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype)
dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype) dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
output = _activation_lu_fp8(x, dx_trans_no_use, dbias_no_use, amax, output = _activation_lu_fp8(x, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv, fwd_dtype,
scale, scale_inv, fwd_dtype, bwd_dtype, activation_type) bwd_dtype, activation_type)
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(6,7,8))
def _activation_lu_fp8(x: jnp.ndarray, @partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray, def _activation_lu_fp8(x: jnp.ndarray, dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray,
amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
activation_type: Sequence[Union[str, Callable]]): activation_type: Sequence[Union[str, Callable]]):
output = _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax, output = _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv,
scale, scale_inv, fwd_dtype, bwd_dtype, fwd_dtype, bwd_dtype, activation_type)
activation_type)
return output return output
def _activation_lu_fp8_fwd_rule(x,
def _activation_lu_fp8_fwd_rule(
x,
dx_trans_no_use, # pylint: disable=unused-argument dx_trans_no_use, # pylint: disable=unused-argument
dbias_no_use, # pylint: disable=unused-argument dbias_no_use, # pylint: disable=unused-argument
amax, amax,
scale, scale_inv, scale,
fwd_dtype, bwd_dtype, # pylint: disable=unused-argument scale_inv,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
activation_type): activation_type):
activation_lu_out, _ = activation_fp8_dict[activation_type ]["fwd"]( activation_lu_out, _ = activation_fp8_dict[activation_type]["fwd"](x, amax, scale, scale_inv,
x, amax, scale, scale_inv, fwd_dtype) fwd_dtype)
activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv) activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
ctx = (x, amax, scale, scale_inv) ctx = (x, amax, scale, scale_inv)
return activation_lu_out, ctx return activation_lu_out, ctx
def _activation_lu_fp8_bwd_rule(fwd_dtype, bwd_dtype, # pylint: disable=unused-argument
activation_type, ctx, g): def _activation_lu_fp8_bwd_rule(
fwd_dtype, # pylint: disable=unused-argument
bwd_dtype,
activation_type,
ctx,
g):
x, amax, scale, scale_inv = ctx x, amax, scale, scale_inv = ctx
activation_lu_fp8_bwd = activation_fp8_dict[activation_type]["bwd"] activation_lu_fp8_bwd = activation_fp8_dict[activation_type]["bwd"]
...@@ -139,6 +166,7 @@ def _activation_lu_fp8_bwd_rule(fwd_dtype, bwd_dtype, # pylint: disable=unused ...@@ -139,6 +166,7 @@ def _activation_lu_fp8_bwd_rule(fwd_dtype, bwd_dtype, # pylint: disable=unused
ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv) ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv)
return ctx return ctx
_activation_lu_fp8.defvjp(_activation_lu_fp8_fwd_rule, _activation_lu_fp8_bwd_rule) _activation_lu_fp8.defvjp(_activation_lu_fp8_fwd_rule, _activation_lu_fp8_bwd_rule)
...@@ -200,15 +228,12 @@ def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarr ...@@ -200,15 +228,12 @@ def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarr
epsilon: float, layernorm_input_axes: Tuple[str, ...], epsilon: float, layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn2_ckpt_name: str, ffn1_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]], use_bias: bool):
use_bias: bool): output, _ = _fused_layernorm_fp8_mlp_fwd_rule(
output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1, x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, amax, scale, scale_inv,
bias_2, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon, layernorm_input_axes,
fwd_dtype, bwd_dtype, layernorm_type, dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type,
zero_centered_gamma, epsilon, use_bias)
layernorm_input_axes, dot_1_input_axes,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
activation_type, use_bias)
return output return output
...@@ -256,6 +281,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule( ...@@ -256,6 +281,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
if not is_gated: if not is_gated:
kernel_1 = jnp.squeeze(kernel_1, axis=-2) kernel_1 = jnp.squeeze(kernel_1, axis=-2)
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
amax = FP8Helper.update_amax_history(amax) amax = FP8Helper.update_amax_history(amax)
gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
...@@ -324,8 +354,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule( ...@@ -324,8 +354,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
activation_lu_fwd_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale, activation_lu_fwd_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
activation_lu_out_scale_inv, fwd_dtype) activation_lu_out_scale_inv, fwd_dtype)
casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out, casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
dot_2_input_axes) casted_activation_lu_out, dot_2_input_axes)
kernel_2_scale = scale[gemm2_kernel_idx] kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx] kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
...@@ -335,8 +365,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule( ...@@ -335,8 +365,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
# (batch..., hidden_in) x (hidden_out, hidden_in) # (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2, dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
activation_lu_out_scale_inv, activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype,
kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)), (x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
if use_bias: if use_bias:
...@@ -348,7 +378,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule( ...@@ -348,7 +378,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1, ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape) x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape, maybe_fp32_to_fm32)
return dot_2_output, ctx return dot_2_output, ctx
...@@ -371,7 +401,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -371,7 +401,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \ x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \
casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \ casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape= ctx x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx
is_gated = len(activation_type) > 1 is_gated = len(activation_type) > 1
...@@ -481,8 +511,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -481,8 +511,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
xt_batch_dims_2 = xt_batch_dims if not is_gated \ xt_batch_dims_2 = xt_batch_dims if not is_gated \
else tuple(i + 1 for i in xt_batch_dims) else tuple(i + 1 for i in xt_batch_dims)
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv, wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
dactivation_lu_scale_inv, grad.dtype, dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
(xt_batch_dims, xt_batch_dims_2),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# Expand act axis to match the shape with the given kernel_1 # Expand act axis to match the shape with the given kernel_1
if not is_gated: if not is_gated:
...@@ -490,14 +519,13 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -490,14 +519,13 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
# (batch..., hidden_out) x (hidden_in, hidden_out) # (batch..., hidden_out) x (hidden_in, hidden_out)
if is_gated: if is_gated:
x_contracting_dims = ((min(x_contracting_dims),) + tuple( x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
i + 1 for i in x_contracting_dims), (1,2)) (1, 2))
else: else:
x_contracting_dims = (x_contracting_dims, (1,)) x_contracting_dims = (x_contracting_dims, (1,))
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv,
dactivation_lu_scale_inv, kernel_1_scale_inv, kernel_1_scale_inv, grad.dtype, x_contracting_dims,
grad.dtype, x_contracting_dims,
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)
...@@ -523,7 +551,8 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -523,7 +551,8 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax) amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0]) amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)
return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \ return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
fp8_max, amax, scale, scale_inv fp8_max, amax, scale, scale_inv
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment