Unverified Commit 9416519d authored by Kirthi Shankar Sivamani, committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
......@@ -14,15 +14,18 @@ from . import cpp_extensions as tex
class SoftmaxType(Enum):
"""SoftmaxType."""
SCALED = "scaled"
SCALED_MASKED = "scaled_masked"
SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked"
def softmax(logits: jnp.ndarray,
def softmax(
logits: jnp.ndarray,
mask: Optional[jnp.ndarray] = None,
scale_factor: Optional[float] = 1.0,
softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED):
softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED,
):
"""
Softmax wrapper
"""
......@@ -50,7 +53,7 @@ def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type):
def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz):
softmax_output, = ctx
(softmax_output,) = ctx
if softmax_type is SoftmaxType.SCALED_MASKED:
dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, scale_factor)
......
......@@ -6,6 +6,7 @@
# pylint: disable=wrong-import-position,wrong-import-order
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
from transformer_engine import transformer_engine_paddle # pylint: disable=unused-import
......
......@@ -13,6 +13,7 @@ from transformer_engine import transformer_engine_paddle as tex
class FP8FwdTensors(Enum):
"""Used as named indices on the `scale`, `scale_inv`,
and `amax` tensors in the `FP8TensorMeta` class."""
GEMM1_INPUT = 0
GEMM1_WEIGHT = 1
GEMM1_OUTPUT = 2
......@@ -24,6 +25,7 @@ class FP8FwdTensors(Enum):
class FP8BwdTensors(Enum):
"""Used as named indices on the `scale`, `scale_inv`,
and `amax` tensors in the `FP8TensorMeta` class."""
GRAD_OUTPUT1 = 0
GRAD_INPUT1 = 1
GRAD_OUTPUT2 = 2
......@@ -51,7 +53,7 @@ GemmParallelModes = ("row", "column", None)
dist_group_type = paddle.distributed.collective.Group
RecomputeFunctionNames = ('unpack', 'backward')
RecomputeFunctionNames = ("unpack", "backward")
AttnBiasType = {
"no_bias": tex.NVTE_Bias_Type.NVTE_NO_BIAS,
......
......@@ -66,8 +66,9 @@ def gemm(
bias = bias if use_bias else None
assert A.dtype == dtype and B.dtype == dtype, \
f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
assert (
A.dtype == dtype and B.dtype == dtype
), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}"
input_dtype = TE_DType[dtype]
output_dtype = TE_DType[out.dtype]
if use_bias:
......@@ -270,7 +271,9 @@ def cast_transpose(
dtype=paddle.uint8,
)
else:
assert transpose_out.shape == [inp.shape[1], inp.shape[0]
assert transpose_out.shape == [
inp.shape[1],
inp.shape[0],
], "Transposed output shape does not match input shape."
assert transpose_out.dtype == paddle.uint8, "Output should be of uint8 dtype."
......@@ -348,7 +351,9 @@ def swiglu(
)
def swiglu_pd(inp: paddle.Tensor,) -> paddle.Tensor:
def swiglu_pd(
inp: paddle.Tensor,
) -> paddle.Tensor:
"""Native SWIGLU"""
gate_out, up_out = paddle.chunk(inp, chunks=2, axis=-1)
out = F.silu(gate_out) * up_out
......@@ -423,11 +428,19 @@ def layernorm_fwd_fp8(
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""LayerNorm with FP8 output"""
out, mu, rsigma, _, _ = tex.te_layernorm_fwd_fp8(inp, weight, bias, fp8_meta_tensor.scale,
out, mu, rsigma, _, _ = tex.te_layernorm_fwd_fp8(
inp,
weight,
bias,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv, eps,
fp8_tensor.value, int(otype), sm_margin,
zero_centered_gamma)
fp8_meta_tensor.scale_inv,
eps,
fp8_tensor.value,
int(otype),
sm_margin,
zero_centered_gamma,
)
return out, mu, rsigma
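As a reading aid, a plain-Paddle reference (an illustrative sketch, not the fused kernel) for what the call above returns: the normalized output plus the per-row mean and inverse standard deviation that the backward pass reuses. The FP8 scale/amax handling is omitted here.

import paddle

def layernorm_fwd_reference(inp, weight, bias, eps=1e-5):
    # mu / rsigma are returned so the backward pass can avoid recomputing statistics.
    mu = inp.mean(axis=-1, keepdim=True)
    var = inp.var(axis=-1, unbiased=False, keepdim=True)
    rsigma = paddle.rsqrt(var + eps)
    out = (inp - mu) * rsigma * weight + bias
    return out, mu.squeeze(-1), rsigma.squeeze(-1)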
......@@ -480,10 +493,18 @@ def rmsnorm_fwd_fp8(
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""RMSNorm with FP8 output"""
out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(inp, weight, fp8_meta_tensor.scale,
out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(
inp,
weight,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv, eps, fp8_tensor.value,
int(otype), sm_margin, zero_centered_gamma)
fp8_meta_tensor.scale_inv,
eps,
fp8_tensor.value,
int(otype),
sm_margin,
zero_centered_gamma,
)
return out, rsigma
......@@ -533,8 +554,10 @@ def fused_attn_fwd_qkvpacked(
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed QKV input"""
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert qkv_dtype in (
tex.DType.kBFloat16,
tex.DType.kFloat16,
), "Only support bf16/fp16 for fused attention."
b = cu_seqlens.shape[0] - 1
total_seqs = qkv.shape[0] * qkv.shape[1]
......@@ -546,17 +569,23 @@ def fused_attn_fwd_qkvpacked(
if bias_type != "no_bias":
assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert (Bias.shape == [1, h, max_seqlen, max_seqlen
]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv."
assert Bias.shape == [
1,
h,
max_seqlen,
max_seqlen,
], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert Bias.dtype == qkv.dtype, "bias tensor must be in the same dtype as qkv."
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
assert (
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA -
1) // BACKEND_F16m512_THREADS_PER_CTA
rng_elts_per_thread = (
max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA - 1
) // BACKEND_F16m512_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
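The (a + b - 1) // b expression a few lines above is integer ceiling division: each CTA thread must cover ceil(max_seqlen^2 / threads) dropout RNG elements. A small standalone check; the thread count used here is illustrative, not taken from the hunk.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

max_seqlen = 512
threads_per_cta = 128  # assumed value for illustration only
print(ceil_div(max_seqlen * max_seqlen, threads_per_cta))  # 2048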
......@@ -571,15 +600,18 @@ def fused_attn_fwd_qkvpacked(
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype='float32')
softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype="float32")
else:
raise ValueError("Unsupported fused attention backend.")
else:
softmax_aux = None
rng_state = paddle.empty(shape=[
rng_state = paddle.empty(
shape=[
2,
], dtype=paddle.int64)
],
dtype=paddle.int64,
)
# execute kernel
tex.te_fused_attn_fwd_qkvpacked(
......@@ -625,8 +657,10 @@ def fused_attn_bwd_qkvpacked(
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention BWD for packed QKV input"""
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert qkv_dtype in (
tex.DType.kBFloat16,
tex.DType.kFloat16,
), "Only support bf16/fp16 for fused attention."
b = cu_seqlens.shape[0] - 1
total_seqs = qkv.shape[0] * qkv.shape[1]
......@@ -636,7 +670,8 @@ def fused_attn_bwd_qkvpacked(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
assert (
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
if set_zero:
......@@ -694,9 +729,12 @@ def fused_attn_fwd_kvpacked(
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed KV input"""
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
assert qkv_dtype in (
tex.DType.kBFloat16,
tex.DType.kFloat16,
), "Only support bf16/fp16 for fused attention."
assert (
cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
b = cu_seqlens_q.shape[0] - 1
......@@ -710,17 +748,23 @@ def fused_attn_fwd_kvpacked(
if bias_type != "no_bias":
assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert (Bias.shape == [1, h, max_seqlen_q, max_seqlen_kv
]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv."
assert Bias.shape == [
1,
h,
max_seqlen_q,
max_seqlen_kv,
], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as q and kv."
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
assert (
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA -
1) // BACKEND_F16m512_THREADS_PER_CTA
rng_elts_per_thread = (
max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1
) // BACKEND_F16m512_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
......@@ -735,15 +779,18 @@ def fused_attn_fwd_kvpacked(
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32')
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32")
else:
raise ValueError("Unsupported fused attention backend.")
else:
softmax_aux = None
rng_state = paddle.empty(shape=[
rng_state = paddle.empty(
shape=[
2,
], dtype=paddle.int64)
],
dtype=paddle.int64,
)
# execute kernel
tex.te_fused_attn_fwd_kvpacked(
......@@ -797,9 +844,12 @@ def fused_attn_bwd_kvpacked(
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention BWD for packed KV input"""
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
assert qkv_dtype in (
tex.DType.kBFloat16,
tex.DType.kFloat16,
), "Only support bf16/fp16 for fused attention."
assert (
cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
b = cu_seqlens_q.shape[0] - 1
......@@ -811,7 +861,8 @@ def fused_attn_bwd_kvpacked(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
assert (
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
if set_zero:
......@@ -875,11 +926,15 @@ def fused_attn_fwd(
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for unpacked QKV input"""
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
assert qkv_dtype in (
tex.DType.kBFloat16,
tex.DType.kFloat16,
), "Only support bf16/fp16 for fused attention."
assert (
cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
assert (qkv_layout == "bshd_bshd_bshd"
assert (
qkv_layout == "bshd_bshd_bshd"
), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."
b = cu_seqlens_q.shape[0] - 1
......@@ -891,18 +946,23 @@ def fused_attn_fwd(
if bias_type != "no_bias":
assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert (Bias.shape == [
1, h, max_seqlen_q, max_seqlen_kv
]), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as qkv."
assert Bias.shape == [
1,
h,
max_seqlen_q,
max_seqlen_kv,
], "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as qkv."
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
assert (
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA -
1) // BACKEND_F16m512_THREADS_PER_CTA
rng_elts_per_thread = (
max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1
) // BACKEND_F16m512_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
......@@ -917,15 +977,18 @@ def fused_attn_fwd(
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32')
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32")
else:
raise ValueError("Unsupported fused attention backend.")
else:
softmax_aux = None
rng_state = paddle.empty(shape=[
rng_state = paddle.empty(
shape=[
2,
], dtype=paddle.int64)
],
dtype=paddle.int64,
)
# execute kernel
tex.te_fused_attn_fwd(
......@@ -978,11 +1041,15 @@ def fused_attn_bwd(
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention BWD for packed KV input"""
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
assert qkv_dtype in (
tex.DType.kBFloat16,
tex.DType.kFloat16,
), "Only support bf16/fp16 for fused attention."
assert (
cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
assert (qkv_layout == "bshd_bshd_bshd"
assert (
qkv_layout == "bshd_bshd_bshd"
), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."
b = cu_seqlens_q.shape[0] - 1
......@@ -992,7 +1059,8 @@ def fused_attn_bwd(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
assert (
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
if set_zero:
......@@ -1041,7 +1109,7 @@ def scaled_softmax_forward(
inp: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled softmax forward"""
"""scaled softmax forward"""
return tex.te_scaled_softmax_forward(inp, scale_factor)
......@@ -1050,7 +1118,7 @@ def scaled_softmax_backward(
softmax_results: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled softmax backward"""
"""scaled softmax backward"""
tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor)
return out_grad
......@@ -1060,7 +1128,7 @@ def scaled_masked_softmax_forward(
mask: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled masked softmax forward"""
"""scaled masked softmax forward"""
return tex.te_scaled_masked_softmax_forward(inp, mask, scale_factor)
......@@ -1070,7 +1138,7 @@ def scaled_masked_softmax_backward(
softmax_results: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled masked softmax backward"""
"""scaled masked softmax backward"""
tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor)
return out_grad
......@@ -1079,7 +1147,7 @@ def scaled_upper_triang_masked_softmax_forward(
inp: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled upper triang masked softmax forward"""
"""scaled upper triang masked softmax forward"""
return tex.te_scaled_upper_triang_masked_softmax_forward(inp, scale_factor)
......@@ -1088,6 +1156,6 @@ def scaled_upper_triang_masked_softmax_backward(
softmax_results: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
""" scaled upper triang masked softmax backward"""
"""scaled upper triang masked softmax backward"""
tex.te_scaled_upper_triang_masked_softmax_backward(out_grad, softmax_results, scale_factor)
return out_grad
......@@ -38,12 +38,10 @@ paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const pad
bool init_to_zeros) {
auto size = shape.ndim;
if (size == 2 && init_to_zeros) {
return paddle::zeros(
{static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
return paddle::zeros({static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
Nvte2PaddleDType(type), place);
} else if (size == 2) {
return paddle::empty(
{static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
return paddle::empty({static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
Nvte2PaddleDType(type), place);
} else if (size == 1 && init_to_zeros) {
return paddle::zeros({static_cast<int64_t>(shape.data[0])}, Nvte2PaddleDType(type), place);
......
......@@ -5,13 +5,7 @@
************************************************************************/
#pragma once
#include <cstdlib>
#include <vector>
#include <cublasLt.h>
#include "paddle/extension.h"
#include "paddle/phi/backends/all_context.h"
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/fused_attn.h>
......@@ -22,7 +16,13 @@
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <cstdlib>
#include <vector>
#include "common/util/logging.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/all_context.h"
namespace transformer_engine {
namespace paddle_ext {
......@@ -129,12 +129,12 @@ inline DType Int2NvteDType(int64_t dtype) {
inline NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv, head_dim);
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim) {
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads,
num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim);
return fused_attention_backend;
}
......
......@@ -103,8 +103,8 @@ void te_cast_transpose(const paddle::Tensor &input, const paddle::Tensor &scale,
void *amax_data = GetDataPtr<float>(amax, index);
void *scale_data = const_cast<void *>(GetDataPtr<float>(scale, index));
void *scale_inv_data = GetDataPtr<float>(scale_inv, index);
auto output_cast_cu = MakeNvteTensor(output_cast.data(), {M, N}, Int2NvteDType(otype),
amax_data, scale_data, scale_inv_data);
auto output_cast_cu = MakeNvteTensor(output_cast.data(), {M, N}, Int2NvteDType(otype), amax_data,
scale_data, scale_inv_data);
auto output_transpose_cu = MakeNvteTensor(output_transpose.data(), {N, M}, Int2NvteDType(otype),
amax_data, scale_data, scale_inv_data);
......@@ -125,8 +125,8 @@ std::vector<paddle::Tensor> te_cast_transpose_bgrad(const paddle::Tensor &grad_o
auto grad_bias =
paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place());
auto grad_output_cast = paddle::empty_like(grad_output, Nvte2PaddleDType(Int2NvteDType(otype)),
grad_output.place());
auto grad_output_cast =
paddle::empty_like(grad_output, Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place());
auto grad_output_transpose =
paddle::empty({grad_output.shape()[1], grad_output.shape()[0]},
Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place());
......@@ -179,8 +179,7 @@ void te_gemm(const paddle::Tensor &A, const paddle::optional<paddle::Tensor> &A_
auto te_bias = MakeNvteTensor(const_cast<void *>(GetOptionalDataPtr(bias)), GetShapeArray(bias),
Int2NvteDType(bias_type));
DType gelu_dtype =
pre_gelu_out ? Paddle2NvteDType(pre_gelu_out->dtype()) : Int2NvteDType(D_type);
DType gelu_dtype = pre_gelu_out ? Paddle2NvteDType(pre_gelu_out->dtype()) : Int2NvteDType(D_type);
auto te_pre_gelu_out =
MakeNvteTensor(GetOptionalDataPtr(pre_gelu_out), GetShapeArray(pre_gelu_out), gelu_dtype);
auto te_workspace =
......@@ -294,8 +293,7 @@ std::vector<paddle::Tensor> te_cast_transpose_bgrad_dgelu(const paddle::Tensor &
auto grad_bias =
paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place());
auto dgelu =
paddle::empty_like(grad_output, Nvte2PaddleDType(DType::kByte), grad_output.place());
auto dgelu = paddle::empty_like(grad_output, Nvte2PaddleDType(DType::kByte), grad_output.place());
auto dgelu_transpose = paddle::empty({grad_output.shape()[1], grad_output.shape()[0]},
Nvte2PaddleDType(DType::kByte), grad_output.place());
......@@ -345,8 +343,7 @@ std::vector<paddle::Tensor> te_layernorm_fwd_fp8(const paddle::Tensor &input,
auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto mu = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto rsigma =
paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto rsigma = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
auto gamma_cu = MakeNvteTensor(weight);
auto beta_cu = MakeNvteTensor(bias);
......@@ -389,8 +386,7 @@ std::vector<paddle::Tensor> te_layernorm_fwd(const paddle::Tensor &input,
auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
auto mu = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto rsigma =
paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto rsigma = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
auto gamma_cu = MakeNvteTensor(weight);
auto beta_cu = MakeNvteTensor(bias);
......@@ -442,9 +438,9 @@ std::vector<paddle::Tensor> te_layernorm_bwd(const paddle::Tensor &dz, const pad
// This call populates tensors with the required config.
const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(),
dz.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), dz.stream(),
num_sm - sm_margin, workspace.data(), barrier.data());
// Alloc space for Tensors.
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place());
......@@ -457,9 +453,9 @@ std::vector<paddle::Tensor> te_layernorm_bwd(const paddle::Tensor &dz, const pad
dbeta_part = MakeNvteTensor(dbeta_part_data.data(), dbeta_part.shape(), dbeta_part.dtype());
// Actual call to bwd kernel.
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(),
dz.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), dz.stream(),
num_sm - sm_margin, workspace.data(), barrier.data());
return {dx, dgamma, dbeta};
}
......@@ -467,8 +463,7 @@ std::vector<paddle::Tensor> te_layernorm_bwd(const paddle::Tensor &dz, const pad
std::vector<paddle::Tensor> te_rmsnorm_fwd(const paddle::Tensor &input,
const paddle::Tensor &weight, float eps, int64_t otype,
int64_t sm_margin, bool zero_centered_gamma) {
NVTE_CHECK(zero_centered_gamma == false,
"zero_centered_gamma is not supported yet for RMSNorm.");
NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm.");
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
......@@ -476,8 +471,7 @@ std::vector<paddle::Tensor> te_rmsnorm_fwd(const paddle::Tensor &input,
size_t H = shape[1];
auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
auto rsigma =
paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto rsigma = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
auto gamma_cu = MakeNvteTensor(weight);
auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype));
......@@ -511,8 +505,7 @@ std::vector<paddle::Tensor> te_rmsnorm_fwd_fp8(const paddle::Tensor &input,
paddle::Tensor &scale_inv, // NOLINT
float eps, int64_t index, int64_t otype,
int64_t sm_margin, bool zero_centered_gamma) {
NVTE_CHECK(zero_centered_gamma == false,
"zero_centered_gamma is not supported yet for RMSNorm.");
NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm.");
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
......@@ -520,8 +513,7 @@ std::vector<paddle::Tensor> te_rmsnorm_fwd_fp8(const paddle::Tensor &input,
size_t H = shape[1];
auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto rsigma =
paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto rsigma = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
auto gamma_cu = MakeNvteTensor(weight);
auto z_cu = MakeNvteTensor(
......@@ -553,8 +545,7 @@ std::vector<paddle::Tensor> te_rmsnorm_bwd(const paddle::Tensor &dz, const paddl
const paddle::Tensor &rsigma,
const paddle::Tensor &gamma, int64_t sm_margin,
bool zero_centered_gamma) {
NVTE_CHECK(zero_centered_gamma == false,
"zero_centered_gamma is not supported yet for RMSNorm.");
NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm.");
auto dx = paddle::empty_like(x, x.dtype(), x.place());
auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place());
......@@ -652,25 +643,24 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(), QKV.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
auto *output_s =
reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *output_s = reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(), QKV.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -740,9 +730,9 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
QKV.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
......@@ -752,9 +742,9 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
QKV.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -828,30 +818,27 @@ void te_fused_attn_fwd_kvpacked(
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
auto *output_s =
reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *output_s = reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -907,9 +894,9 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
nvte_aux_tensor_pack.size = 2;
auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *fwd_rng_state = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[1]);
output_s->data.shape = std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen_q),
static_cast<size_t>(max_seqlen_kv)});
output_s->data.shape =
std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen_q), static_cast<size_t>(max_seqlen_kv)});
output_s->data.dptr = const_cast<void *>(softmax_aux.data());
fwd_rng_state->data.shape = std::vector<size_t>({2});
fwd_rng_state->data.dptr = const_cast<void *>(rng_state.data());
......@@ -926,26 +913,26 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -1014,28 +1001,27 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
auto *output_s =
reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *output_s = reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -1092,9 +1078,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_aux_tensor_pack.size = 2;
auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *fwd_rng_state = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[1]);
output_s->data.shape = std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen_q),
static_cast<size_t>(max_seqlen_kv)});
output_s->data.shape =
std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen_q), static_cast<size_t>(max_seqlen_kv)});
output_s->data.dptr = const_cast<void *>(softmax_aux.data());
fwd_rng_state->data.shape = std::vector<size_t>({2});
fwd_rng_state->data.dptr = const_cast<void *>(rng_state.data());
......@@ -1111,12 +1097,11 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
......@@ -1125,12 +1110,11 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
......@@ -1141,8 +1125,8 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
std::vector<paddle::Tensor> te_scaled_softmax_forward(const paddle::Tensor &input,
float scale_factor) {
NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor");
NVTE_CHECK((input.dtype() == paddle::DataType::FLOAT16) ||
(input.dtype() == paddle::DataType::BFLOAT16),
NVTE_CHECK(
(input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16),
"Only fp16 and bf16 are supported");
const int batches = input.shape()[0];
......@@ -1190,8 +1174,8 @@ std::vector<paddle::Tensor> te_scaled_masked_softmax_forward(const paddle::Tenso
float scale_factor) {
NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor");
NVTE_CHECK(mask.shape().size() == 4, "expected 4D tensor");
NVTE_CHECK((input.dtype() == paddle::DataType::FLOAT16) ||
(input.dtype() == paddle::DataType::BFLOAT16),
NVTE_CHECK(
(input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16),
"Only fp16 and bf16 are supported");
const int batches = input.shape()[0];
......@@ -1243,8 +1227,8 @@ void te_scaled_masked_softmax_backward(paddle::Tensor &output_grads, // NOLINT
std::vector<paddle::Tensor> te_scaled_upper_triang_masked_softmax_forward(
const paddle::Tensor &input, float scale_factor) {
NVTE_CHECK(input.shape().size() == 3, "expected 3D tensor");
NVTE_CHECK((input.dtype() == paddle::DataType::FLOAT16) ||
(input.dtype() == paddle::DataType::BFLOAT16),
NVTE_CHECK(
(input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16),
"Only fp16 and bf16 are supported");
const int attn_batches = input.shape()[0];
......@@ -1291,34 +1275,23 @@ constexpr int BLOCK_SIZE = 512;
void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT
paddle::Tensor &scale, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
const paddle::Tensor &non_weight_mask,
int64_t fp8_dtype,
float margin,
const std::string &amax_compute) {
const paddle::Tensor &non_weight_mask, int64_t fp8_dtype,
float margin, const std::string &amax_compute) {
auto amax_history_ = MakeNvteTensor(amax_history);
auto scale_ = MakeNvteTensor(scale);
auto scale_inv_ = MakeNvteTensor(scale_inv);
const auto non_weight_mask_ = MakeNvteTensor(non_weight_mask);
nvte_delayed_scaling_recipe_amax_and_scale_update(
amax_history_.data(),
scale_.data(),
scale_inv_.data(),
non_weight_mask_.data(),
amax_history_.data(),
scale_.data(),
scale_inv_.data(),
amax_compute.c_str(),
static_cast<NVTEDType>(fp8_dtype),
margin,
amax_history.stream());
amax_history_.data(), scale_.data(), scale_inv_.data(), non_weight_mask_.data(),
amax_history_.data(), scale_.data(), scale_inv_.data(), amax_compute.c_str(),
static_cast<NVTEDType>(fp8_dtype), margin, amax_history.stream());
}
void update_latest_amax_history_inplace(paddle::Tensor &history, // NOLINT
const paddle::Tensor &amax) {
// Copy amax to history[0]
NVTE_CHECK_CUDA(cudaMemcpyAsync(history.data(), amax.data(),
amax.numel() * SizeOf(amax.dtype()), cudaMemcpyDeviceToDevice,
amax.stream()));
NVTE_CHECK_CUDA(cudaMemcpyAsync(history.data(), amax.data(), amax.numel() * SizeOf(amax.dtype()),
cudaMemcpyDeviceToDevice, amax.stream()));
}
__global__ __launch_bounds__(BLOCK_SIZE) void mask_to_actual_seqlens_kernel(
......@@ -1389,8 +1362,7 @@ void mask_to_cu_seqlens(const paddle::Tensor &mask,
}
mask_to_actual_seqlens_kernel<<<mask.shape()[0], BLOCK_SIZE, 0, mask.stream()>>>(
mask.data<bool>(), q_cu_seqlen.data<int32_t>(),
reinterpret_cast<int32_t *>(GetOptionalDataPtr(kv_cu_seqlen)), q_seqlen, kv_seqlen,
need_kv);
reinterpret_cast<int32_t *>(GetOptionalDataPtr(kv_cu_seqlen)), q_seqlen, kv_seqlen, need_kv);
// q_cu_seqlen shape: [bs+1], assume bs is not too large (<=512), so we can use a single block
// to do prefix sum
NVTE_CHECK(q_cu_seqlen.numel() - 1 <= BLOCK_SIZE, "batch size too large, kernel may fail");
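A host-side NumPy restatement of what this kernel computes may help: per-sample valid lengths from the padding mask, followed by a prefix sum into a [batch + 1] cu_seqlens array. The mask convention used below (True marks a padded position) is an assumption for illustration.

import numpy as np

def mask_to_cu_seqlens_reference(mask: np.ndarray) -> np.ndarray:
    # mask: [batch, seqlen] boolean, True = padded position (assumption).
    seqlens = (~mask).sum(axis=-1).astype(np.int32)   # actual length per sample
    cu_seqlens = np.zeros(mask.shape[0] + 1, dtype=np.int32)
    np.cumsum(seqlens, out=cu_seqlens[1:])            # prefix sum -> [batch + 1]
    return cu_seqlens

mask = np.array([[False, False, True, True],
                 [False, False, False, True]])
print(mask_to_cu_seqlens_reference(mask))  # [0 2 5]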
......
......@@ -17,26 +17,25 @@ from paddle.distributed.fleet.layers.mpu import mp_ops
from .constants import dist_group_type
_weight_split_axis = {
'transformer_engine': {
'row': 1,
'column': 0
},
'paddle': {
'row': 0,
'column': 1
}
"transformer_engine": {"row": 1, "column": 0},
"paddle": {"row": 0, "column": 1},
}
def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None],
enable_tp: bool = True) -> Tuple[Union[dist_group_type, None], int]:
def get_tp_group_and_world_size(
tp_group: Union[dist_group_type, None], enable_tp: bool = True
) -> Tuple[Union[dist_group_type, None], int]:
"""Get TP group and world size using Fleet API"""
if not (paddle.distributed.is_initialized() and enable_tp):
return None, 1
model_parallel_group = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group()
if tp_group is None else tp_group)
world_size = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()
if tp_group is None else tp_group.nranks)
model_parallel_group = (
tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group() if tp_group is None else tp_group
)
world_size = (
tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()
if tp_group is None
else tp_group.nranks
)
"""
When using TP, the NCCL communication needs to be scheduled
before the GEMM for a guaranteed overlap. From the host side
......@@ -47,8 +46,10 @@ def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None],
"""
num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
if num_cuda_work_queues != 1:
warnings.warn("To guarantee overlapping TP and SP collectives with the backward"
"GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1")
warnings.warn(
"To guarantee overlapping TP and SP collectives with the backward"
"GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
)
return model_parallel_group, world_size
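To satisfy the check above, the variable has to be set before CUDA work queues are created in the process; a sketch is shown below, though exporting it in the launcher environment is the more robust option.

import os

# Must happen before the first CUDA call in the process.
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")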
......@@ -73,8 +74,9 @@ def set_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, axis: int) ->
tensor.split_axis = axis
def set_weight_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool,
parallel_mode: Optional[str], backend: str) -> None:
def set_weight_tensor_dist_attr(
tensor: paddle.Tensor, is_parallel: bool, parallel_mode: Optional[str], backend: str
) -> None:
"""Set distributed attributes for the weight tensor"""
if not is_parallel or parallel_mode is None:
return
......@@ -149,17 +151,15 @@ def reduce_scatter(
parallelism = tp_group.nranks
output_shape = input_.shape
assert (
input_.shape[0] % parallelism == 0
), f"Input sequence length {input_.shape[0]} can't be divided " \
assert input_.shape[0] % parallelism == 0, (
f"Input sequence length {input_.shape[0]} can't be divided "
f"exactly by sequence parallelism {parallelism}"
)
output_shape[0] = output_shape[0] // parallelism
output = paddle.empty(shape=output_shape, dtype=input_.dtype)
wait_handle = paddle.distributed.stream.reduce_scatter(output,
input_,
op=paddle.distributed.ReduceOp.SUM,
group=tp_group,
sync_op=sync_op)
wait_handle = paddle.distributed.stream.reduce_scatter(
output, input_, op=paddle.distributed.ReduceOp.SUM, group=tp_group, sync_op=sync_op
)
if sync_op:
return output, None
return output, wait_handle
......
......@@ -15,7 +15,7 @@ from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type
from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer
__all__ = ['fp8_autocast']
__all__ = ["fp8_autocast"]
# FP8 support
_is_fp8_available = None
......@@ -124,8 +124,13 @@ class FP8State:
fp8_group: Optional[dist_group_type],
) -> None:
"""Called when entering 'fp8_autocast'"""
self.saved_states = (self._fp8_enabled, self._fp8_calibration, self._fp8_recipe,
self._fp8_distributed_group, self._is_first_fp8_module)
self.saved_states = (
self._fp8_enabled,
self._fp8_calibration,
self._fp8_recipe,
self._fp8_distributed_group,
self._is_first_fp8_module,
)
self._fp8_enabled = enabled
self._fp8_calibration = calibrating
......@@ -140,8 +145,13 @@ class FP8State:
def exit(self):
"""Called when exiting 'fp8_autocast'"""
# Restore saved states
(self._fp8_enabled, self._fp8_calibration, self._fp8_recipe, self._fp8_distributed_group,
self._is_first_fp8_module) = self.saved_states
(
self._fp8_enabled,
self._fp8_calibration,
self._fp8_recipe,
self._fp8_distributed_group,
self._is_first_fp8_module,
) = self.saved_states
self._fp8_autocast_depth -= 1
......@@ -214,8 +224,9 @@ def fp8_autocast(
def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (fp8_recipe.fp8_format == Format.HYBRID
and fprop_tensor):
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return tex.DType.kFloat8E4M3
return tex.DType.kFloat8E5M2
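Restated as a standalone truth table (plain strings stand in for the Format/tex.DType enums): HYBRID keeps E4M3 for forward tensors and E5M2 for gradients, while pure E4M3 uses E4M3 everywhere.

def pick_fp8_dtype(fp8_format: str, fprop_tensor: bool) -> str:
    if fp8_format == "E4M3" or (fp8_format == "HYBRID" and fprop_tensor):
        return "kFloat8E4M3"
    return "kFloat8E5M2"

assert pick_fp8_dtype("HYBRID", fprop_tensor=True) == "kFloat8E4M3"
assert pick_fp8_dtype("HYBRID", fprop_tensor=False) == "kFloat8E5M2"
assert pick_fp8_dtype("E4M3", fprop_tensor=False) == "kFloat8E4M3"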
......@@ -241,14 +252,17 @@ def amax_and_scale_update(
non_weight_mask=non_weight_mask,
fp8_dtype=int(get_fp8_te_dtype(fp8_meta["recipe"], fwd_update)),
margin=float(fp8_meta["recipe"].margin),
amax_compute=amax_compute)
amax_compute=amax_compute,
)
else:
raise ValueError("We only support the fp8 recipe with 'max' or 'most_recent' "
raise ValueError(
"We only support the fp8 recipe with 'max' or 'most_recent' "
"amax_compute_algo and default scaling_factor_compute_algo at this "
"moment.")
"moment."
)
class FP8TensorMeta():
class FP8TensorMeta:
"""Holds FP8 scaling and amax history for FP8 layers"""
def __init__(self, is_forward: bool):
......@@ -281,20 +295,22 @@ class FP8TensorMeta():
self.amax_history = self.amax_history[:amax_history_len]
elif amax_history_len > curr_len:
extra_rows = amax_history_len - curr_len
self.amax_history = paddle.concat([
self.amax_history = paddle.concat(
[
self.amax_history,
paddle.zeros((extra_rows, num_fp8_tensors), dtype='float32')
paddle.zeros((extra_rows, num_fp8_tensors), dtype="float32"),
],
axis=0)
axis=0,
)
return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
num_fp8_tensors = (num_gemms * 3 if self.is_forward else num_gemms * 2)
num_fp8_tensors = num_gemms * 3 if self.is_forward else num_gemms * 2
self.scale = paddle.ones(num_fp8_tensors, dtype='float32')
self.scale_inv = paddle.ones(num_fp8_tensors, dtype='float32')
self.amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype='float32')
self.scale = paddle.ones(num_fp8_tensors, dtype="float32")
self.scale_inv = paddle.ones(num_fp8_tensors, dtype="float32")
self.amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype="float32")
self.non_weight_mask = self.get_non_weight_mask(num_gemms=num_gemms)
self.is_initialized = True
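The sizing logic above, pulled out as a standalone sketch: forward metadata tracks three FP8 tensors per GEMM (input, weight, output) and backward tracks two (grad_output, grad_input), with one amax-history row per history step.

import paddle

def init_fp8_meta_tensors(num_gemms: int, amax_history_len: int, is_forward: bool):
    num_fp8_tensors = num_gemms * 3 if is_forward else num_gemms * 2
    scale = paddle.ones([num_fp8_tensors], dtype="float32")
    scale_inv = paddle.ones([num_fp8_tensors], dtype="float32")
    amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype="float32")
    return scale, scale_inv, amax_history

scale, scale_inv, amax_history = init_fp8_meta_tensors(num_gemms=1, amax_history_len=1024, is_forward=True)
print(amax_history.shape)  # [1024, 3]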
......@@ -303,16 +319,16 @@ class FP8TensorMeta():
"""Convert FP8 meta tensors to numpy."""
assert self.is_initialized, "FP8TensorMeta is not initialized yet."
return {
'scale': self.scale.numpy(),
'scale_inv': self.scale_inv.numpy(),
'amax_history': self.amax_history.numpy(),
"scale": self.scale.numpy(),
"scale_inv": self.scale_inv.numpy(),
"amax_history": self.amax_history.numpy(),
}
def from_numpy(self, data: Dict[str, np.array]):
"""Set FP8 meta tensors from numpy"""
self.scale = paddle.to_tensor(data['scale'])
self.scale_inv = paddle.to_tensor(data['scale_inv'])
self.amax_history = paddle.to_tensor(data['amax_history'])
self.scale = paddle.to_tensor(data["scale"])
self.scale_inv = paddle.to_tensor(data["scale_inv"])
self.amax_history = paddle.to_tensor(data["amax_history"])
num_fp8_tensors = self.scale.shape[0]
num_gemms = num_fp8_tensors // 3 if self.is_forward else num_fp8_tensors // 2
......
......@@ -49,7 +49,7 @@ class FP8MetaBufferBase(ABC):
def _execute_deletion(self) -> None:
"""Delete the key from global amax buffer."""
if (self._buffer_delete_key is not None and self._buffer_delete_key in self._data):
if self._buffer_delete_key is not None and self._buffer_delete_key in self._data:
del self._data[self._buffer_delete_key]
def _wait_handle_and_split(
......@@ -137,11 +137,12 @@ class FP8MetaBufferBase(ABC):
fp8_meta[buffer_position_key] = len(self._data[buffer_key]) - 1
# Catch incorrect fp8_autocast usage.
assert fp8_meta[buffer_position_key] == len(self._data[buffer_key]) - 1, \
"Same module is being invoked more than once inside an `fp8_autocast` " \
"region when using FP8 with amax reduction. This behavior is currently " \
"unsupported. For more details and correct usage, please see " \
assert fp8_meta[buffer_position_key] == len(self._data[buffer_key]) - 1, (
"Same module is being invoked more than once inside an `fp8_autocast` "
"region when using FP8 with amax reduction. This behavior is currently "
"unsupported. For more details and correct usage, please see "
"https://github.com/NVIDIA/TransformerEngine/pull/93."
)
def copy_amax_from_buffer(self, fp8_meta: Dict[str, Any]) -> None:
"""Populate current amax with the correct location from buffer."""
......@@ -156,7 +157,8 @@ class FP8MetaBufferBase(ABC):
# Copy amax to amax_history[0]
tex.update_latest_amax_history_inplace(
_history=fp8_meta[fp8_meta_tensor_key].amax_history,
amax=self._data[amax_buffer_key][fp8_meta[buffer_position_key]])
amax=self._data[amax_buffer_key][fp8_meta[buffer_position_key]],
)
def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None:
"""Delete this amax key from global buffer during autocast end."""
......@@ -224,7 +226,7 @@ class FP8MetaFwdBuffer(FP8MetaBufferBase):
Called at FP8 autocast end.
Performs AMAX reduction and delete unused buffer entries.
"""
if hasattr(self, '_amax_global_reduce_func') and callable(self._amax_global_reduce_func):
if hasattr(self, "_amax_global_reduce_func") and callable(self._amax_global_reduce_func):
self._amax_reduce_wait_func = self._amax_global_reduce_func()
self._execute_deletion()
......@@ -270,7 +272,7 @@ class FP8RecomputeBuffer:
@staticmethod
def get_buffer_position_key():
"""Returns the key (in fp8_meta) for recompute buffer position"""
return 'recompute_buffer_pos'
return "recompute_buffer_pos"
def stash_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None:
"""Stash the scaling factors and amaxes for recompute"""
......@@ -308,11 +310,13 @@ class FP8RecomputeBuffer:
@staticmethod
def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
"""Restore latest scaling factors and amaxes after recompute forward run."""
assert "updated_amax_history_fwd" in fp8_meta, "Recompute internal error." \
" If you are not using recompute, please check if" \
" the forward function is called from one of these functions: " \
f"{RecomputeFunctionNames}. If so, consider change the function name " \
assert "updated_amax_history_fwd" in fp8_meta, (
"Recompute internal error."
" If you are not using recompute, please check if"
" the forward function is called from one of these functions: "
f"{RecomputeFunctionNames}. If so, consider change the function name "
"or set NVTE_DISABLE_RECOMPUTE=1."
)
fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
......@@ -10,6 +10,7 @@ from typing import Optional, Tuple, Union
import paddle
import paddle.nn.functional as F
try:
from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
......@@ -19,8 +20,14 @@ from transformer_engine import transformer_engine_paddle as tex
from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .softmax import FusedScaleMaskSoftmax
from ..constants import (AttnTypes, TE_DType, AttnBiasType, AttnMaskType, FusedAttnBackend,
dist_group_type)
from ..constants import (
AttnTypes,
TE_DType,
AttnBiasType,
AttnMaskType,
FusedAttnBackend,
dist_group_type,
)
from ..cpp_extensions import (
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
......@@ -72,8 +79,9 @@ class RotaryPositionEmbedding(paddle.nn.Layer):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.inv_freq = 1.0 / (10000**(paddle.cast(paddle.arange(0, dim, 2), dtype='float32') /
self.dim))
self.inv_freq = 1.0 / (
10000 ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / self.dim)
)
self._set_cos_sin_cache(seq_len=max_position_embeddings)
def _set_cos_sin_cache(self, seq_len):
......@@ -104,8 +112,8 @@ class RotaryPositionEmbedding(paddle.nn.Layer):
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x
......@@ -114,8 +122,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
if position_ids is None:
# Note: Only for LlamaForCausalLMPipe model pretraining
cos = cos[:, :q.shape[1], :, :] # [bs, seq_len, 1, dim]
sin = sin[:, :q.shape[1], :, :] # [bs, seq_len, 1, dim]
cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
else:
cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim]
sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim]
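# Illustrative sketch (assuming the standard RoPE formulation that rotate_half
# above supports): the rotation applied to the query/key tensors is
#     q_embed = (q * cos) + (rotate_half(q) * sin)
#     k_embed = (k * cos) + (rotate_half(k) * sin)
# which, for each feature pair (x1, x2) and angle theta, is the 2-D rotation
#     (x1 * cos(theta) - x2 * sin(theta), x2 * cos(theta) + x1 * sin(theta)).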
......@@ -130,9 +138,22 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
"""Function for FusedAttention with packed QKV input"""
@staticmethod
def forward(ctx, qkv, cu_seqlens, attn_bias, max_seqlen, attn_scale, qkv_dtype, dropout_p,
set_zero, qkv_layout, attn_bias_type, attn_mask_type, is_training,
fused_attention_backend):
def forward(
ctx,
qkv,
cu_seqlens,
attn_bias,
max_seqlen,
attn_scale,
qkv_dtype,
dropout_p,
set_zero,
qkv_layout,
attn_bias_type,
attn_mask_type,
is_training,
fused_attention_backend,
):
"""Forward function for FusedAttention with packed QKV input"""
out, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
qkv,
......@@ -167,11 +188,23 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
def backward(ctx, d_out):
"""Backward function for FusedAttention with packed QKV input"""
qkv, out, cu_seqlens, rng_state, softmax_aux = ctx.saved_tensor()
dqkv, *rest = fused_attn_bwd_qkvpacked(qkv, cu_seqlens, rng_state, out, d_out, softmax_aux,
ctx.fused_attention_backend, ctx.max_seqlen,
ctx.qkv_dtype, ctx.attn_scale, ctx.dropout_p,
ctx.set_zero, ctx.qkv_layout, ctx.attn_bias_type,
ctx.attn_mask_type)
dqkv, *rest = fused_attn_bwd_qkvpacked(
qkv,
cu_seqlens,
rng_state,
out,
d_out,
softmax_aux,
ctx.fused_attention_backend,
ctx.max_seqlen,
ctx.qkv_dtype,
ctx.attn_scale,
ctx.dropout_p,
ctx.set_zero,
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
)
# if no_bias, return dqkv
if ctx.attn_bias_type == "no_bias":
......@@ -184,14 +217,44 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
"""Function for FusedAttention with packed KV input"""
@staticmethod
def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv,
attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type,
attn_mask_type, is_training, fused_attention_backend):
def forward(
ctx,
q,
kv,
cu_seqlens_q,
cu_seqlens_kv,
attn_bias,
max_seqlen_q,
max_seqlen_kv,
attn_scale,
qkv_dtype,
dropout_p,
set_zero,
qkv_layout,
attn_bias_type,
attn_mask_type,
is_training,
fused_attention_backend,
):
"""Forward function for FusedAttention with packed KV input"""
out, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
q, kv, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv, qkv_dtype,
fused_attention_backend, attn_bias, attn_scale, dropout_p, set_zero, qkv_layout,
attn_bias_type, attn_mask_type)
q,
kv,
cu_seqlens_q,
cu_seqlens_kv,
is_training,
max_seqlen_q,
max_seqlen_kv,
qkv_dtype,
fused_attention_backend,
attn_bias,
attn_scale,
dropout_p,
set_zero,
qkv_layout,
attn_bias_type,
attn_mask_type,
)
ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
ctx.max_seqlen_q = max_seqlen_q
......@@ -211,12 +274,26 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
def backward(ctx, d_out):
"""Backward function for FusedAttention with packed KV input"""
q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
dq, dkv, *rest = fused_attn_bwd_kvpacked(q, kv, cu_seqlens_q, cu_seqlens_kv, rng_state, out,
d_out, softmax_aux, ctx.fused_attention_backend,
ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype,
ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
ctx.qkv_layout, ctx.attn_bias_type,
ctx.attn_mask_type)
dq, dkv, *rest = fused_attn_bwd_kvpacked(
q,
kv,
cu_seqlens_q,
cu_seqlens_kv,
rng_state,
out,
d_out,
softmax_aux,
ctx.fused_attention_backend,
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
ctx.qkv_dtype,
ctx.attn_scale,
ctx.dropout_p,
ctx.set_zero,
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
)
# if no_bias, return dq, dkv
if ctx.attn_bias_type == "no_bias":
......@@ -229,15 +306,46 @@ class FusedAttnFunc(paddle.autograd.PyLayer):
"""Function for FusedAttention with separate Q, K, V tensors"""
@staticmethod
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv,
attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type,
attn_mask_type, is_training, fused_attention_backend):
def forward(
ctx,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
attn_bias,
max_seqlen_q,
max_seqlen_kv,
attn_scale,
qkv_dtype,
dropout_p,
set_zero,
qkv_layout,
attn_bias_type,
attn_mask_type,
is_training,
fused_attention_backend,
):
"""Forward function for FusedAttention with separate Q, K, V tensors"""
out, softmax_aux, rng_state = fused_attn_fwd(q, k, v, cu_seqlens_q, cu_seqlens_kv,
is_training, max_seqlen_q, max_seqlen_kv,
qkv_dtype, fused_attention_backend, attn_bias,
attn_scale, dropout_p, set_zero, qkv_layout,
attn_bias_type, attn_mask_type)
out, softmax_aux, rng_state = fused_attn_fwd(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
is_training,
max_seqlen_q,
max_seqlen_kv,
qkv_dtype,
fused_attention_backend,
attn_bias,
attn_scale,
dropout_p,
set_zero,
qkv_layout,
attn_bias_type,
attn_mask_type,
)
ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
ctx.max_seqlen_q = max_seqlen_q
......@@ -257,11 +365,27 @@ class FusedAttnFunc(paddle.autograd.PyLayer):
def backward(ctx, d_out):
"""Backward function for FusedAttention with separate Q, K, V tensors"""
q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
dq, dk, dv, *rest = fused_attn_bwd(q, k, v, cu_seqlens_q, cu_seqlens_kv, rng_state, out,
d_out, softmax_aux, ctx.fused_attention_backend,
ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype,
ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
dq, dk, dv, *rest = fused_attn_bwd(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
rng_state,
out,
d_out,
softmax_aux,
ctx.fused_attention_backend,
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
ctx.qkv_dtype,
ctx.attn_scale,
ctx.dropout_p,
ctx.set_zero,
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
)
# if no_bias, return dq, dk, dv
if ctx.attn_bias_type == "no_bias":
return (dq, dk, dv, None, None)
......@@ -306,7 +430,8 @@ class DotProductAttention(paddle.nn.Layer):
backend to use for attention operation.
"""
def __init__(self,
def __init__(
self,
num_attention_heads: int,
kv_channels: int,
num_gqa_groups: Optional[int] = None,
......@@ -314,7 +439,8 @@ class DotProductAttention(paddle.nn.Layer):
attn_mask_type: str = "causal",
attention_type: str = "self",
tp_size: int = 1,
backend: str = 'transformer_engine') -> None:
backend: str = "transformer_engine",
) -> None:
super().__init__()
self.attn_mask_type = attn_mask_type
......@@ -324,7 +450,7 @@ class DotProductAttention(paddle.nn.Layer):
self.hidden_size_per_attention_head = kv_channels
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.tp_size = tp_size
self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups)
self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups
......@@ -332,14 +458,14 @@ class DotProductAttention(paddle.nn.Layer):
self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1")))
if not self.use_fused_attention and backend == 'transformer_engine':
if not self.use_fused_attention and backend == "transformer_engine":
warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
self.backend = 'paddle'
self.backend = "paddle"
if self.backend != 'transformer_engine':
self.scale_mask_softmax = FusedScaleMaskSoftmax(attn_mask_type,
attention_mask_func,
backend=self.backend)
if self.backend != "transformer_engine":
self.scale_mask_softmax = FusedScaleMaskSoftmax(
attn_mask_type, attention_mask_func, backend=self.backend
)
def forward(
self,
......@@ -380,35 +506,53 @@ class DotProductAttention(paddle.nn.Layer):
backend = self.backend
assert (key_layer.shape == value_layer.shape), "Keys and values must have the same shape!"
assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
assert (
key_layer.shape[-2] == self.num_gqa_groups_per_partition
), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
if backend == 'transformer_engine':
if backend == "transformer_engine":
max_s_q = query_layer.shape[1]
max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1]
self.fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type],
AttnMaskType[self.attn_mask_type], self.attention_dropout, query_layer.shape[-2],
key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2], max_s_q,
max_s_kv, query_layer.shape[-1])
is_backend_avail = (self.fused_attention_backend in [
FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]
])
TE_DType[query_layer.dtype],
TE_DType[query_layer.dtype],
tex.get_nvte_qkv_layout(self.qkv_layout),
AttnBiasType[core_attention_bias_type],
AttnMaskType[self.attn_mask_type],
self.attention_dropout,
query_layer.shape[-2],
key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2],
max_s_q,
max_s_kv,
query_layer.shape[-1],
)
is_backend_avail = self.fused_attention_backend in [
FusedAttnBackend["F16_max512_seqlen"],
FusedAttnBackend["F16_arbitrary_seqlen"],
]
if is_backend_avail and self.use_fused_attention:
return self._te_forward(query_layer, key_layer, value_layer, attention_mask,
core_attention_bias_type, core_attention_bias, set_zero)
return self._te_forward(
query_layer,
key_layer,
value_layer,
attention_mask,
core_attention_bias_type,
core_attention_bias,
set_zero,
)
warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
backend = 'paddle'
self.scale_mask_softmax = FusedScaleMaskSoftmax(self.attn_mask_type,
attention_mask_func,
backend=backend)
if backend == 'paddle':
backend = "paddle"
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.attn_mask_type, attention_mask_func, backend=backend
)
if backend == "paddle":
if core_attention_bias_type != "no_bias":
warnings.warn("Paddle backend dot product attention does not support bias yet. "
"Bias will be ignored.")
warnings.warn(
"Paddle backend dot product attention does not support bias yet. "
"Bias will be ignored."
)
return self._pd_forward(query_layer, key_layer, value_layer, attention_mask)
raise AttributeError(f"Backend {backend} is not supported.")
......@@ -425,45 +569,76 @@ class DotProductAttention(paddle.nn.Layer):
if self.attention_type == "self":
# self attention - q: [b, s, h, d] kv: None
assert (len(query_layer.shape) == 4 and len(key_layer.shape) == 4
and len(value_layer.shape)
== 4), "q,k,v shape must be [b, s, h, d] for dot product self attention"
assert (
len(query_layer.shape) == 4
and len(key_layer.shape) == 4
and len(value_layer.shape) == 4
), "q,k,v shape must be [b, s, h, d] for dot product self attention"
max_seqlen = query_layer.shape[1]
if self.attn_mask_type == "causal" or attention_mask is None:
cu_seqlens = paddle.arange(0, (query_layer.shape[0] + 1) * query_layer.shape[1],
cu_seqlens = paddle.arange(
0,
(query_layer.shape[0] + 1) * query_layer.shape[1],
step=query_layer.shape[1],
dtype='int32')
dtype="int32",
)
else:
cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False)
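# For illustration: with batch_size=2 and seq_len=3, the arange branch above gives
# cu_seqlens = [0, 3, 6], i.e. cumulative sequence offsets into the packed token
# dimension that the fused attention kernels expect.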
qkv_dtype = TE_DType[query_layer.dtype]
output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens,
cu_seqlens, core_attention_bias, max_seqlen, max_seqlen,
1.0 / self.norm_factor, qkv_dtype,
self.attention_dropout if self.training else 0.0, set_zero,
self.qkv_layout, core_attention_bias_type,
self.attn_mask_type, self.training,
self.fused_attention_backend)
output = FusedAttnFunc.apply(
query_layer,
key_layer,
value_layer,
cu_seqlens,
cu_seqlens,
core_attention_bias,
max_seqlen,
max_seqlen,
1.0 / self.norm_factor,
qkv_dtype,
self.attention_dropout if self.training else 0.0,
set_zero,
self.qkv_layout,
core_attention_bias_type,
self.attn_mask_type,
self.training,
self.fused_attention_backend,
)
elif self.attention_type == "cross":
# cross attention - q: [b, s_q, h, d] k,v: [b, s_kv, h, d]
assert (
len(query_layer.shape) == 4 and len(key_layer.shape) == 4
len(query_layer.shape) == 4
and len(key_layer.shape) == 4
and len(value_layer.shape) == 4
), "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]" \
), (
"query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]"
"for dot product cross attention"
assert (attention_mask
is not None), "attention_mask must be provided for cross attention"
)
assert attention_mask is not None, "attention_mask must be provided for cross attention"
max_seqlen_q = query_layer.shape[1]
max_seqlen_kv = key_layer.shape[1]
cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True)
qkv_dtype = TE_DType[query_layer.dtype]
output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens_q,
cu_seqlens_kv, core_attention_bias, max_seqlen_q,
max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype,
self.attention_dropout if self.training else 0.0, set_zero,
self.qkv_layout, core_attention_bias_type,
self.attn_mask_type, self.training,
self.fused_attention_backend)
output = FusedAttnFunc.apply(
query_layer,
key_layer,
value_layer,
cu_seqlens_q,
cu_seqlens_kv,
core_attention_bias,
max_seqlen_q,
max_seqlen_kv,
1.0 / self.norm_factor,
qkv_dtype,
self.attention_dropout if self.training else 0.0,
set_zero,
self.qkv_layout,
core_attention_bias_type,
self.attn_mask_type,
self.training,
self.fused_attention_backend,
)
else:
raise ValueError("attention_type must be one of ['self', 'cross']")
return output
......@@ -595,8 +770,8 @@ class MultiHeadAttention(paddle.nn.Layer):
tp_group: Optional[dist_group_type] = None,
num_gqa_groups: Optional[int] = None,
fuse_wgrad_accumulation: bool = False,
rng_state_name: str = 'local_seed',
backend: str = 'transformer_engine',
rng_state_name: str = "local_seed",
backend: str = "transformer_engine",
) -> None:
super().__init__()
self.input_layernorm = input_layernorm
......@@ -610,8 +785,9 @@ class MultiHeadAttention(paddle.nn.Layer):
assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
enable_tp=set_parallel_mode)
self.tp_group, self.tp_size = get_tp_group_and_world_size(
tp_group, enable_tp=set_parallel_mode
)
self.tensor_parallel = self.tp_size > 1
self.sequence_parallel = self.tensor_parallel and sequence_parallel
self.hidden_size_per_attention_head = hidden_size // num_attention_heads
......@@ -621,10 +797,12 @@ class MultiHeadAttention(paddle.nn.Layer):
self.backend = backend
self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups)
assert (self.num_attention_heads % self.num_gqa_groups == 0
self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
assert (
self.num_attention_heads % self.num_gqa_groups == 0
), "The number of attention heads must be divisible by the number of GQA groups!"
assert (self.num_gqa_groups % self.tp_size == 0
assert (
self.num_gqa_groups % self.tp_size == 0
), "The number of GQA groups must be divisible by tensor parallel size!"
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads)
......@@ -776,7 +954,7 @@ class MultiHeadAttention(paddle.nn.Layer):
"""
if self.attn_mask_type != "causal" and attention_mask is not None:
assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor"
assert attention_mask.dtype == paddle.bool, "Attention mask must be a boolean tensor"
input_dim = len(hidden_states.shape)
if input_dim == 2:
......@@ -806,15 +984,20 @@ class MultiHeadAttention(paddle.nn.Layer):
is_first_microbatch=is_first_microbatch,
)
num_queries_per_key_value = (self.num_attention_heads_per_partition //
self.num_gqa_groups_per_partition)
num_queries_per_key_value = (
self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
)
# [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d]
mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[
-1, max_seq_len, (
num_queries_per_key_value +
2), self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head
])
mixed_qkv_layer = mixed_qkv_layer.reshape(
shape=[
-1,
max_seq_len,
(num_queries_per_key_value + 2),
self.num_gqa_groups_per_partition,
self.hidden_size_per_attention_head,
]
)
# [b, s_q, (h/ng+2), ng, d]
# --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d]
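# Worked example (illustrative): with h=8 attention heads, ng=2 GQA groups and
# head size d=64, num_queries_per_key_value = 4, so the reshape above yields
# [b, s_q, 6, 2, 64]; splitting along that third axis gives query slices of
# [b, s_q, 4, 2, 64] and key/value slices of [b, s_q, 1, 2, 64] each, which the
# later reshape flattens to [b, s_q, 8, 64] and [b, s_q, 2, 64].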
......@@ -826,9 +1009,10 @@ class MultiHeadAttention(paddle.nn.Layer):
# query: -> [b, s, h, d]
# key, value: -> [b, s, ng, d]
query_layer, key_layer, value_layer = (x.reshape(
shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head])
for x in (query_layer, key_layer, value_layer))
query_layer, key_layer, value_layer = (
x.reshape(shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head])
for x in (query_layer, key_layer, value_layer)
)
else: # cross attention
mixed_kv_layer = self.key_value(
......@@ -836,9 +1020,14 @@ class MultiHeadAttention(paddle.nn.Layer):
is_first_microbatch=is_first_microbatch,
)
# [b, s_kv, 2 * hidden_size_kv] --> [b, s_kv, 2 * ng, head_size]
mixed_kv_layer = mixed_kv_layer.reshape(shape=[
0, 0, 2 * self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head
])
mixed_kv_layer = mixed_kv_layer.reshape(
shape=[
0,
0,
2 * self.num_gqa_groups_per_partition,
self.hidden_size_per_attention_head,
]
)
# [b, s_kv, 2 * ng, head_size]
# --> 2 [b, s_kv, ng, head_size]
......@@ -864,16 +1053,21 @@ class MultiHeadAttention(paddle.nn.Layer):
)
# [b, s, hidden_size] --> [b, s, h, d]
query_layer = query_layer.reshape(shape=[
-1, max_seq_len, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head
])
query_layer = query_layer.reshape(
shape=[
-1,
max_seq_len,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
]
)
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
if fused_rotary_position_embedding is None:
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, q_pos_emb,
k_pos_emb)
query_layer, key_layer = apply_rotary_pos_emb(
query_layer, key_layer, q_pos_emb, k_pos_emb
)
else:
query_layer, key_layer, _ = fused_rotary_position_embedding(
query_layer,
......@@ -911,10 +1105,12 @@ class MultiHeadAttention(paddle.nn.Layer):
if input_dim == 3:
context_layer = paddle.reshape(
context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]])
context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]]
)
else: # input_dim == 2
context_layer = paddle.reshape(context_layer,
[-1, context_layer.shape[2] * context_layer.shape[3]])
context_layer = paddle.reshape(
context_layer, [-1, context_layer.shape[2] * context_layer.shape[3]]
)
# Output. [b, s, hidden]
attention_output = self.proj(context_layer, is_first_microbatch=is_first_microbatch)
......
......@@ -12,6 +12,7 @@ from typing import Generator, Dict, Tuple, Union, Any, List, Optional
import numpy as np
import paddle
try:
from paddle.base import core
from paddle.base.framework import _dygraph_tracer
......@@ -52,7 +53,7 @@ def get_workspace() -> paddle.Tensor:
if _cublas_workspace is None:
_cublas_workspace = paddle.empty(
[get_cublas_workspace_size_bytes()],
dtype='uint8',
dtype="uint8",
)
return _cublas_workspace
......@@ -62,7 +63,7 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
def __init__(self) -> None:
super().__init__()
assert 'gpu' in paddle.device.get_device(), "TransformerEngine needs CUDA."
assert "gpu" in paddle.device.get_device(), "TransformerEngine needs CUDA."
self.fp8_initialized = False
self.fp8_enabled = False
self.fp8_calibration = False
......@@ -77,7 +78,8 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.sequence_parallel = False
self.fp8_meta["autocast_id_fwd_stack"] = []
self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")))
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
)
self.fp8_weight_shapes = []
self.fp8_weight_cache = {}
......@@ -86,11 +88,11 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
tracer = _dygraph_tracer()
if tracer and tracer._amp_level != core.AmpLevel.O0:
# Set activation_dtype to the Paddle AMP dtype if under 'paddle.amp.auto_cast' context
if tracer._amp_dtype == 'float32':
if tracer._amp_dtype == "float32":
self.activation_dtype = paddle.float32
elif tracer._amp_dtype == 'bfloat16':
elif tracer._amp_dtype == "bfloat16":
self.activation_dtype = paddle.bfloat16
elif tracer._amp_dtype == 'float16':
elif tracer._amp_dtype == "float16":
self.activation_dtype = paddle.float16
else:
raise RuntimeError(f"AMP format {tracer._amp_dtype} is not supported.")
......@@ -110,7 +112,8 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}")
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype
......@@ -125,8 +128,10 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
if self.fp8_enabled or self.fp8_calibration:
# FP8 init has already been run and recipe is the same, don't do anything.
if self.fp8_initialized and global_fp8_state.get_fp8_recipe(
) == self.fp8_meta["recipe"]:
if (
self.fp8_initialized
and global_fp8_state.get_fp8_recipe() == self.fp8_meta["recipe"]
):
return
# Set FP8, recipe, and other FP8 metadata
......@@ -156,8 +161,10 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
weight_cast_key = f"weight{i}_fp8"
weight_transpose_key = f"weight{i}_t_fp8"
if (weight_cast_key in self.fp8_weight_cache
and self.fp8_weight_cache[weight_cast_key].shape == shape):
if (
weight_cast_key in self.fp8_weight_cache
and self.fp8_weight_cache[weight_cast_key].shape == shape
):
return
self.fp8_weight_cache[weight_cast_key] = paddle.empty(
......@@ -231,7 +238,8 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
# Load extra items.
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = self.fp8_meta["scaling_fwd"].amax_history.shape[
0]
0
]
recompute_buffer_pos_key = FP8RecomputeBuffer.get_buffer_position_key()
if recompute_buffer_pos_key in self.fp8_meta:
del self.fp8_meta[recompute_buffer_pos_key]
......@@ -271,9 +279,10 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.set_fp8_weights()
if self.fp8_enabled and self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, \
"Amax reduction across tensor parallel group is " \
assert self.fp8_meta["recipe"].reduce_amax, (
"Amax reduction across tensor parallel group is "
"necessary when using sequence parallelism with FP8."
)
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
......@@ -283,14 +292,14 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
global_fp8_fwd_buffer.wait()
if self.fp8_meta["recipe"].reduce_amax:
global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta)
amax_and_scale_update(self.fp8_meta,
True,
update_weight_scale_inv=update_weight_scale_inv)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta)
else:
amax_and_scale_update(self.fp8_meta,
True,
update_weight_scale_inv=update_weight_scale_inv)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
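# Sketch of what amax_and_scale_update does under the default delayed-scaling
# recipe (illustrative; the exact rule is taken from fp8_meta["recipe"]):
#     amax      = reduce(amax_history)                 # e.g. max over the window
#     scale     = 2 ** (floor(log2(fp8_max / amax)) - margin)
#     scale_inv = 1.0 / scale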
if self.fp8_enabled and self.training:
# Setup for amax reduction
......@@ -304,8 +313,11 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.fp8_meta["update_amax_and_scale_fwd"] = False
# Activation recomputation is used and this is the first forward phase.
if (self.fp8_enabled and self.training
and get_global_fp8_state().is_fp8_recompute_enabled()):
if (
self.fp8_enabled
and self.training
and get_global_fp8_state().is_fp8_recompute_enabled()
):
global_recompute_buffer = get_global_fp8_state().get_fp8_recompute_buffer()
global_recompute_buffer.stash_fp8_meta_tensors(self.fp8_meta)
......@@ -328,11 +340,13 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
@staticmethod
@contextmanager
def prepare_backward(fp8_enabled: bool,
def prepare_backward(
fp8_enabled: bool,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
name: str = "") -> Generator[None, None, None]:
name: str = "",
) -> Generator[None, None, None]:
"""Checks and prep for BWD."""
if fp8_enabled:
global_fp8_state = get_global_fp8_state()
......@@ -358,8 +372,9 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
global_fp8_bwd_buffer.finalize(fp8_meta, tp_group, tp_size)
@staticmethod
def grad_output_preprocess(ctx, grad_output: paddle.Tensor,
row_parallel_mode: bool) -> Tuple[Union[paddle.Tensor, None], ...]:
def grad_output_preprocess(
ctx, grad_output: paddle.Tensor, row_parallel_mode: bool
) -> Tuple[Union[paddle.Tensor, None], ...]:
"""Utility function for backward.
Returns tuple in order (all optional/None based on training precision/recipe):
R1: gathered `grad_output` in higher precision.
......@@ -447,11 +462,14 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
weight_cast_key = f"weight{i}_fp8"
weight_transpose_key = f"weight{i}_t_fp8"
assert weight_cast_key in self.fp8_weight_cache, \
"TE internal error: fp8 weight buffer is not found"
assert (
weight_cast_key in self.fp8_weight_cache
), "TE internal error: fp8 weight buffer is not found"
out_list.extend([
out_list.extend(
[
self.fp8_weight_cache[weight_cast_key],
self.fp8_weight_cache[weight_transpose_key],
])
]
)
return out_list
......@@ -36,8 +36,15 @@ class _LayerNorm(paddle.autograd.PyLayer):
assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.reshape((-1, in_features))
ln_out, mu, rsigma = layernorm_fwd(inputmat, ln_weight, ln_bias, eps, TE_DType[inp.dtype],
fwd_ln_sm_margin, zero_centered_gamma)
ln_out, mu, rsigma = layernorm_fwd(
inputmat,
ln_weight,
ln_bias,
eps,
TE_DType[inp.dtype],
fwd_ln_sm_margin,
zero_centered_gamma,
)
ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
ctx.inp_shape = inp.shape
......@@ -52,8 +59,9 @@ class _LayerNorm(paddle.autograd.PyLayer):
def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
inputmat, ln_weight, mu, rsigma = ctx.saved_tensor()
d_ln_out = grad_output.reshape(inputmat.shape)
dxmat, dgamma, dbeta = layernorm_bwd(d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma)
dxmat, dgamma, dbeta = layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None,
dgamma if ctx.requires_dw else None,
......@@ -106,7 +114,7 @@ class LayerNorm(paddle.nn.Layer):
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
zero_centered_gamma: bool = False,
sequence_parallel: bool = False,
backend: str = 'transformer_engine',
backend: str = "transformer_engine",
) -> None:
super().__init__()
self.eps = eps
......@@ -117,8 +125,9 @@ class LayerNorm(paddle.nn.Layer):
self._weight_attr = weight_attr
if not self._weight_attr:
self._weight_attr = paddle.ParamAttr(initializer=Constant(
value=0.0 if self.zero_centered_gamma else 1.0))
self._weight_attr = paddle.ParamAttr(
initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0)
)
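# Note: with zero_centered_gamma the kernel is assumed to apply (1 + weight) as
# the scale, so initializing the weight to 0.0 here matches the usual init of 1.0.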
self._bias_attr = bias_attr
if self._bias_attr is False:
......@@ -151,8 +160,15 @@ class LayerNorm(paddle.nn.Layer):
def _te_forward(self, inp: paddle.Tensor) -> paddle.Tensor:
"""LayerNorm FWD"""
return _LayerNorm.apply(inp, self.weight, self.bias, self.eps, self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin, self.zero_centered_gamma)
return _LayerNorm.apply(
inp,
self.weight,
self.bias,
self.eps,
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
)
def _pd_forward(
self,
......@@ -161,18 +177,21 @@ class LayerNorm(paddle.nn.Layer):
"""Calls Paddle OP"""
if self.zero_centered_gamma:
raise NotImplementedError(
"Paddle backend does not support LayerNorm with zero-centered scale.")
"Paddle backend does not support LayerNorm with zero-centered scale."
)
return F.layer_norm(x=inp,
return F.layer_norm(
x=inp,
normalized_shape=inp.shape[-1],
weight=self.weight,
bias=self.bias,
epsilon=self.eps)
epsilon=self.eps,
)
def forward(self, *args, **kwargs):
"""forward"""
if self.backend == 'transformer_engine':
if self.backend == "transformer_engine":
return self._te_forward(*args, **kwargs)
if self.backend == 'paddle':
if self.backend == "paddle":
return self._pd_forward(*args, **kwargs)
raise AttributeError(f"Backend {self.backend} is not supported.")
......@@ -79,14 +79,14 @@ def _apply_normalization_fwd(
}
fwd_normalization_funcs = {
('LayerNorm', True, True): layernorm_fwd,
('LayerNorm', True, False): layernorm_fwd_fp8,
('LayerNorm', False, True): layernorm_fwd,
('LayerNorm', False, False): layernorm_fwd,
('RMSNorm', True, True): rmsnorm_fwd,
('RMSNorm', True, False): rmsnorm_fwd_fp8,
('RMSNorm', False, True): rmsnorm_fwd,
('RMSNorm', False, False): rmsnorm_fwd,
("LayerNorm", True, True): layernorm_fwd,
("LayerNorm", True, False): layernorm_fwd_fp8,
("LayerNorm", False, True): layernorm_fwd,
("LayerNorm", False, False): layernorm_fwd,
("RMSNorm", True, True): rmsnorm_fwd,
("RMSNorm", True, False): rmsnorm_fwd_fp8,
("RMSNorm", False, True): rmsnorm_fwd,
("RMSNorm", False, False): rmsnorm_fwd,
}
if normalization == "LayerNorm":
......@@ -305,13 +305,11 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
@staticmethod
def backward(
ctx, *grad_outputs: Tuple[paddle.Tensor,
...]) -> Tuple[Union[paddle.Tensor, None], ...]:
with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
ctx.fp8_meta,
ctx.tp_group,
ctx.tp_size,
name="_LayerNormLinear"):
ctx, *grad_outputs: Tuple[paddle.Tensor, ...]
) -> Tuple[Union[paddle.Tensor, None], ...]:
with TransformerEngineBaseLayer.prepare_backward(
ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear"
):
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
ln_weight,
......@@ -328,12 +326,14 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
grad_output_c,
grad_output_t,
bgrad,
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0],
ctx.parallel_mode == "row")
) = TransformerEngineBaseLayer.grad_output_preprocess(
ctx, grad_outputs[0], ctx.parallel_mode == "row"
)
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation
and not ctx.is_first_microbatch)
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
......@@ -479,14 +479,14 @@ class LayerNormLinear(TransformerEngineBaseLayer):
eps: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
normalization: str = 'LayerNorm',
normalization: str = "LayerNorm",
return_layernorm_output: bool = False,
zero_centered_gamma: bool = False,
parallel_mode: Optional[str] = None,
sequence_parallel: bool = False,
tp_group: Union[dist_group_type, None] = None,
fuse_wgrad_accumulation: bool = False,
backend: str = 'transformer_engine',
backend: str = "transformer_engine",
) -> None:
super().__init__()
......@@ -494,7 +494,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.out_features = out_features
self.eps = eps
self.normalization = normalization
assert normalization in ['LayerNorm', 'RMSNorm'], "Unsupported normalization type!"
assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!"
self.return_layernorm_output = return_layernorm_output
self.zero_centered_gamma = zero_centered_gamma
self.backend = backend
......@@ -504,13 +504,14 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self._dtype = self._helper.get_default_dtype()
# Set parallel configs
self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
enable_tp=parallel_mode
is not None)
self.tp_group, self.tp_size = get_tp_group_and_world_size(
tp_group, enable_tp=parallel_mode is not None
)
self.tensor_parallel = self.tp_size > 1
self.parallel_mode = parallel_mode
assert (self.parallel_mode
in GemmParallelModes), f"parallel_mode {parallel_mode} not supported"
assert (
self.parallel_mode in GemmParallelModes
), f"parallel_mode {parallel_mode} not supported"
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
......@@ -524,8 +525,9 @@ class LayerNormLinear(TransformerEngineBaseLayer):
# LayerNorm weights
self.ln_weight = self.create_parameter(
shape=[self.in_features],
attr=paddle.ParamAttr(initializer=Constant(
value=0.0 if self.zero_centered_gamma else 1.0)),
attr=paddle.ParamAttr(
initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0)
),
dtype=self._dtype,
is_bias=False,
)
......@@ -548,14 +550,18 @@ class LayerNormLinear(TransformerEngineBaseLayer):
with track_rng_state(enable=self.tensor_parallel):
# TE linear weight is in column major
self.weight = self.create_parameter(
shape=[self.out_features, self.in_features]
if self.backend == 'transformer_engine' else [self.in_features, self.out_features],
shape=(
[self.out_features, self.in_features]
if self.backend == "transformer_engine"
else [self.in_features, self.out_features]
),
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode,
self.backend)
set_weight_tensor_dist_attr(
self.weight, self.tensor_parallel, self.parallel_mode, self.backend
)
self.fp8_weight_shapes.append(self.weight.shape)
# Initialize Linear bias parameter
......@@ -564,8 +570,11 @@ class LayerNormLinear(TransformerEngineBaseLayer):
if self.has_bias:
self.bias = self.create_parameter(
shape=[self.out_features],
attr=self._bias_attr if not use_default_bias else paddle.ParamAttr(
initializer=Constant(value=0.0)),
attr=(
self._bias_attr
if not use_default_bias
else paddle.ParamAttr(initializer=Constant(value=0.0))
),
dtype=self._dtype,
is_bias=True,
)
......@@ -656,26 +665,30 @@ class LayerNormLinear(TransformerEngineBaseLayer):
"""Calls Paddle OP"""
if self.zero_centered_gamma:
raise NotImplementedError(
"Paddle backend does not support LayerNorm with zero-centered scale.")
"Paddle backend does not support LayerNorm with zero-centered scale."
)
if is_first_microbatch is not None:
warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.")
"`is_first_microbatch` is not supported for paddle backend and is ignored."
)
if self.normalization == "RMSNorm":
norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps)
norm_out = inp * norm * self.ln_weight
else: # LayerNorm
norm_out = F.layer_norm(x=inp,
norm_out = F.layer_norm(
x=inp,
normalized_shape=inp.shape[-1],
weight=self.ln_weight,
bias=self.ln_bias,
epsilon=self.eps)
epsilon=self.eps,
)
if self.parallel_mode == 'column' and self.tensor_parallel:
if self.parallel_mode == "column" and self.tensor_parallel:
norm_out = identity(norm_out, self.tp_group)
out = F.linear(norm_out, self.weight, self.bias if self.gemm_bias_fused_add else None)
if self.parallel_mode == 'row' and self.tensor_parallel:
if self.parallel_mode == "row" and self.tensor_parallel:
out, _ = allreduce(out, self.tp_group)
out = out + self.bias if self.bias is not None else out
if self.return_layernorm_output:
......@@ -701,8 +714,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
* during FP8 training, it allows caching of the FP8 versions of
the weights
"""
if self.backend == 'transformer_engine':
if self.backend == "transformer_engine":
return self._te_forward(*args, **kwargs)
if self.backend == 'paddle':
if self.backend == "paddle":
return self._pd_forward(*args, **kwargs)
raise AttributeError(f"Backend {self.backend} is not supported.")
......@@ -88,7 +88,7 @@ def _mlp_forward(
use_fc1_bias,
fp8_meta,
activation_dtype,
'column' if set_parallel_mode else None,
"column" if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
......@@ -123,7 +123,7 @@ def _mlp_forward(
use_fc2_bias,
fp8_meta,
activation_dtype,
'row' if set_parallel_mode else None,
"row" if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
......@@ -141,7 +141,7 @@ def _mlp_forward(
fp8_calibration,
fp8_meta,
activation_dtype,
'column' if set_parallel_mode else None,
"column" if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
......@@ -166,7 +166,7 @@ def _mlp_forward(
fp8_calibration,
fp8_meta,
activation_dtype,
'row' if set_parallel_mode else None,
"row" if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
......@@ -220,7 +220,13 @@ def _mlp_backward(
fc1_bgrad,
fc2_wgrad,
fc2_bgrad,
) = None, None, None, None, None
) = (
None,
None,
None,
None,
None,
)
if fp8_enabled:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......@@ -252,7 +258,7 @@ def _mlp_backward(
True,
requires_fc2_wgrad,
activation_dtype,
'row' if set_parallel_mode else None,
"row" if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
......@@ -316,7 +322,7 @@ def _mlp_backward(
requires_dgrad,
requires_fc1_wgrad,
activation_dtype,
'column' if set_parallel_mode else None,
"column" if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
......@@ -332,7 +338,7 @@ def _mlp_backward(
True,
requires_fc2_wgrad,
activation_dtype,
'row' if set_parallel_mode else None,
"row" if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
......@@ -353,7 +359,7 @@ def _mlp_backward(
requires_dgrad,
requires_fc1_wgrad,
activation_dtype,
'column' if set_parallel_mode else None,
"column" if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
......@@ -532,13 +538,11 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
@staticmethod
def backward(
ctx, *grad_outputs: Tuple[paddle.Tensor,
...]) -> Tuple[Union[paddle.Tensor, None], ...]:
with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
ctx.fp8_meta,
ctx.tp_group,
ctx.tp_size,
name="_LayerNormMLP"):
ctx, *grad_outputs: Tuple[paddle.Tensor, ...]
) -> Tuple[Union[paddle.Tensor, None], ...]:
with TransformerEngineBaseLayer.prepare_backward(
ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP"
):
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
ln_weight,
......@@ -563,8 +567,9 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0], True)
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation
and not ctx.is_first_microbatch)
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
......@@ -731,7 +736,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None,
fuse_wgrad_accumulation: bool = False,
backend: str = 'transformer_engine',
backend: str = "transformer_engine",
) -> None:
super().__init__()
......@@ -750,8 +755,9 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self._dtype = self._helper.get_default_dtype()
# Set parallel configs
self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
enable_tp=set_parallel_mode)
self.tp_group, self.tp_size = get_tp_group_and_world_size(
tp_group, enable_tp=set_parallel_mode
)
self.tensor_parallel = self.tp_size > 1
self.set_parallel_mode = set_parallel_mode
self.sequence_parallel = self.tensor_parallel and sequence_parallel
......@@ -766,8 +772,9 @@ class LayerNormMLP(TransformerEngineBaseLayer):
# LayerNorm weights
self.ln_weight = self.create_parameter(
shape=[self.hidden_size],
attr=paddle.ParamAttr(initializer=Constant(
value=0.0 if self.zero_centered_gamma else 1.0)),
attr=paddle.ParamAttr(
initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0)
),
dtype=self._dtype,
is_bias=False,
)
......@@ -795,16 +802,18 @@ class LayerNormMLP(TransformerEngineBaseLayer):
with track_rng_state(enable=self.tensor_parallel):
self.fc1_weight = self.create_parameter(
shape=[fc1_output_features, self.hidden_size] if self.backend
== 'transformer_engine' else [self.hidden_size, fc1_output_features],
shape=(
[fc1_output_features, self.hidden_size]
if self.backend == "transformer_engine"
else [self.hidden_size, fc1_output_features]
),
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
set_weight_tensor_dist_attr(self.fc1_weight,
self.tensor_parallel,
parallel_mode='column',
backend=self.backend)
set_weight_tensor_dist_attr(
self.fc1_weight, self.tensor_parallel, parallel_mode="column", backend=self.backend
)
self.fp8_weight_shapes.append(self.fc1_weight.shape)
self.has_bias = self._bias_attr is not False
......@@ -825,16 +834,18 @@ class LayerNormMLP(TransformerEngineBaseLayer):
# FC2 weights
self.fc2_weight = self.create_parameter(
shape=[self.hidden_size, self.size_per_partition] if self.backend
== 'transformer_engine' else [self.size_per_partition, self.hidden_size],
shape=(
[self.hidden_size, self.size_per_partition]
if self.backend == "transformer_engine"
else [self.size_per_partition, self.hidden_size]
),
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
set_weight_tensor_dist_attr(self.fc2_weight,
self.tensor_parallel,
parallel_mode='row',
backend=self.backend)
set_weight_tensor_dist_attr(
self.fc2_weight, self.tensor_parallel, parallel_mode="row", backend=self.backend
)
self.fp8_weight_shapes.append(self.fc2_weight.shape)
if self.has_bias:
......@@ -880,8 +891,9 @@ class LayerNormMLP(TransformerEngineBaseLayer):
inp = cast_if_needed(inp, self.activation_dtype)
# Get persistent fp8 weight buffer. None if buffer does not exist.
fc1_weight_fp8, fc1_weight_t_fp8, fc2_weight_fp8, fc2_weight_t_fp8 = \
fc1_weight_fp8, fc1_weight_t_fp8, fc2_weight_fp8, fc2_weight_t_fp8 = (
self.get_fp8_weights_scratchpad(is_first_microbatch)
)
out = _LayerNormMLP.apply(
inp,
......@@ -936,28 +948,33 @@ class LayerNormMLP(TransformerEngineBaseLayer):
"""Calls Paddle OP"""
if self.zero_centered_gamma:
raise NotImplementedError(
"Paddle backend does not support LayerNorm with zero-centered scale.")
"Paddle backend does not support LayerNorm with zero-centered scale."
)
if is_first_microbatch is not None:
warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.")
"`is_first_microbatch` is not supported for paddle backend and is ignored."
)
if self.normalization == "RMSNorm":
norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps)
norm_out = inp * norm * self.ln_weight
else: # LayerNorm
norm_out = F.layer_norm(x=inp,
norm_out = F.layer_norm(
x=inp,
normalized_shape=inp.shape[-1],
weight=self.ln_weight,
bias=self.ln_bias,
epsilon=self.eps)
epsilon=self.eps,
)
if self.set_parallel_mode and self.tensor_parallel:
norm_out = identity(norm_out, self.tp_group)
fc1_out = F.linear(norm_out, self.fc1_weight, self.fc1_bias)
act_func = get_paddle_act_func(self.activation)
act_out = act_func(fc1_out)
out = F.linear(act_out, self.fc2_weight,
self.fc2_bias if self.gemm_bias_fused_add else None)
out = F.linear(
act_out, self.fc2_weight, self.fc2_bias if self.gemm_bias_fused_add else None
)
if self.set_parallel_mode and self.tensor_parallel:
out, _ = allreduce(out, self.tp_group)
out = out + self.fc2_bias if self.fc2_bias is not None else out
......@@ -984,8 +1001,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
* during FP8 training, it allows caching of the FP8 versions of
the weights
"""
if self.backend == 'transformer_engine':
if self.backend == "transformer_engine":
return self._te_forward(*args, **kwargs)
if self.backend == 'paddle':
if self.backend == "paddle":
return self._pd_forward(*args, **kwargs)
raise AttributeError(f"Backend {self.backend} is not supported.")
......@@ -152,22 +152,26 @@ def _linear_fwd_non_fp8(
if fp8_calibration:
# amax of input
fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = \
paddle.max(paddle.abs(inputmat_total)).item()
fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = paddle.max(
paddle.abs(inputmat_total)
).item()
# amax of weight
fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = \
paddle.max(paddle.abs(weight)).item()
fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = paddle.max(
paddle.abs(weight)
).item()
fp8_meta["update_amax_and_scale_fwd"] = True
outputs = gemm(weight,
outputs = gemm(
weight,
inputmat_total,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
gelu=(activation == 'gelu'))
gelu=(activation == "gelu"),
)
if activation == 'gelu':
if activation == "gelu":
gelu_out, _, out = outputs
return out, gelu_out
......@@ -382,7 +386,7 @@ def _linear_bwd_non_fp8(
activation_dtype,
get_workspace(),
layout="NN",
gelu=(activation == 'gelu'),
gelu=(activation == "gelu"),
gelu_input=gelu_input,
grad=True,
)
......@@ -527,8 +531,11 @@ class _Linear(paddle.autograd.PyLayer):
inputmat_t = None
if fp8_enabled:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if (not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled
and not sequence_parallel):
if (
not fp8_meta["recipe"].override_linear_precision.wgrad
and is_grad_enabled
and not sequence_parallel
):
inputmat, inputmat_t = cast_transpose(
inputmat,
fp8_meta["scaling_fwd"],
......@@ -599,11 +606,9 @@ class _Linear(paddle.autograd.PyLayer):
@staticmethod
def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
ctx.fp8_meta,
ctx.tp_group,
ctx.tp_size,
name="_Linear"):
with TransformerEngineBaseLayer.prepare_backward(
ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
):
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
......@@ -618,11 +623,13 @@ class _Linear(paddle.autograd.PyLayer):
grad_output_c,
grad_output_t,
bgrad,
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output,
ctx.parallel_mode == "row")
) = TransformerEngineBaseLayer.grad_output_preprocess(
ctx, grad_output, ctx.parallel_mode == "row"
)
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation
and not ctx.is_first_microbatch)
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
......@@ -730,7 +737,7 @@ class Linear(TransformerEngineBaseLayer):
sequence_parallel: bool = False,
tp_group: Union[dist_group_type, None] = None,
fuse_wgrad_accumulation: bool = False,
backend: str = 'transformer_engine',
backend: str = "transformer_engine",
) -> None:
super().__init__()
self.in_features = in_features
......@@ -741,13 +748,14 @@ class Linear(TransformerEngineBaseLayer):
self._dtype = self._helper.get_default_dtype()
# Set parallel configs
self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
enable_tp=parallel_mode
is not None)
self.tp_group, self.tp_size = get_tp_group_and_world_size(
tp_group, enable_tp=parallel_mode is not None
)
self.tensor_parallel = self.tp_size > 1
self.parallel_mode = parallel_mode
assert (self.parallel_mode
in GemmParallelModes), f"parallel_mode {parallel_mode} not supported"
assert (
self.parallel_mode in GemmParallelModes
), f"parallel_mode {parallel_mode} not supported"
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
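# e.g. out_features=1024 with tp_size=4 leaves 256 output features on each rank;
# in "row" mode it is the input features that are partitioned instead.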
......@@ -762,14 +770,18 @@ class Linear(TransformerEngineBaseLayer):
with track_rng_state(enable=self.tensor_parallel):
# TE linear weight is in column major
self.weight = self.create_parameter(
shape=[self.out_features, self.in_features]
if self.backend == 'transformer_engine' else [self.in_features, self.out_features],
shape=(
[self.out_features, self.in_features]
if self.backend == "transformer_engine"
else [self.in_features, self.out_features]
),
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode,
self.backend)
set_weight_tensor_dist_attr(
self.weight, self.tensor_parallel, self.parallel_mode, self.backend
)
# Initialize bias parameter
self.has_bias = self._bias_attr is not False
......@@ -777,8 +789,11 @@ class Linear(TransformerEngineBaseLayer):
if self.has_bias:
self.bias = self.create_parameter(
shape=[self.out_features],
attr=self._bias_attr if not use_default_bias else paddle.ParamAttr(
initializer=Constant(value=0.0)),
attr=(
self._bias_attr
if not use_default_bias
else paddle.ParamAttr(initializer=Constant(value=0.0))
),
dtype=self._dtype,
is_bias=True,
)
......@@ -849,11 +864,12 @@ class Linear(TransformerEngineBaseLayer):
"""Calls Paddle OP"""
if is_first_microbatch is not None:
warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.")
if self.parallel_mode == 'column' and self.tensor_parallel:
"`is_first_microbatch` is not supported for paddle backend and is ignored."
)
if self.parallel_mode == "column" and self.tensor_parallel:
inp = identity(inp, self.tp_group)
out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None)
if self.parallel_mode == 'row' and self.tensor_parallel:
if self.parallel_mode == "row" and self.tensor_parallel:
out, _ = allreduce(out, self.tp_group)
out = out + self.bias if self.bias is not None else out
return out
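# Tensor-parallel pattern used above (Megatron-style, illustrative): "column"
# mode shards the output features, so the forward pass only needs the identity op
# (whose backward all-reduces the input gradient); "row" mode shards the input
# features, so the partial outputs are all-reduced in the forward pass.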
......@@ -877,8 +893,8 @@ class Linear(TransformerEngineBaseLayer):
* during FP8 training, it allows caching of the FP8 versions of
the weights
"""
if self.backend == 'transformer_engine':
if self.backend == "transformer_engine":
return self._te_forward(*args, **kwargs)
if self.backend == 'paddle':
if self.backend == "paddle":
return self._pd_forward(*args, **kwargs)
raise AttributeError(f"Backend {self.backend} is not supported.")
......@@ -33,8 +33,14 @@ class _RMSNorm(paddle.autograd.PyLayer):
assert inp.shape[-1] == in_features, "RMSNorm not possible"
inputmat = inp.reshape((-1, in_features))
rmsnorm_out, rsigma = rmsnorm_fwd(inputmat, rmsnorm_weight, eps, TE_DType[inp.dtype],
fwd_rmsnorm_sm_margin, zero_centered_gamma)
rmsnorm_out, rsigma = rmsnorm_fwd(
inputmat,
rmsnorm_weight,
eps,
TE_DType[inp.dtype],
fwd_rmsnorm_sm_margin,
zero_centered_gamma,
)
ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma)
ctx.inp_shape = inp.shape
......@@ -49,8 +55,14 @@ class _RMSNorm(paddle.autograd.PyLayer):
def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
inputmat, rmsnorm_weight, rsigma = ctx.saved_tensor()
d_rmsnorm_out = grad_output.reshape(inputmat.shape)
dxmat, dgamma = rmsnorm_bwd(d_rmsnorm_out, inputmat, rsigma, rmsnorm_weight,
ctx.bwd_rmsnorm_sm_margin, ctx.zero_centered_gamma)
dxmat, dgamma = rmsnorm_bwd(
d_rmsnorm_out,
inputmat,
rsigma,
rmsnorm_weight,
ctx.bwd_rmsnorm_sm_margin,
ctx.zero_centered_gamma,
)
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None,
dgamma if ctx.requires_dw else None,
......@@ -149,7 +161,8 @@ class RMSNorm(paddle.nn.Layer):
) -> paddle.Tensor:
if self.zero_centered_gamma:
raise NotImplementedError(
"Paddle backend does not support RMSNorm with zero_centered_gamma.")
"Paddle backend does not support RMSNorm with zero_centered_gamma."
)
norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps)
y = inp * norm * self.weight
return y
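# Equivalently, y = x * weight / sqrt(mean(x**2) + eps): RMSNorm rescales by the
# root-mean-square of the activations instead of subtracting the mean as
# LayerNorm does.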
......
......@@ -32,8 +32,9 @@ _default_causal_mask = {}
def _get_default_causal_mask(seqlen: int) -> paddle.Tensor:
"""Return the causal upper triangular mask for softmax input"""
if seqlen not in _default_causal_mask:
_default_causal_mask[seqlen] = paddle.triu(paddle.ones((seqlen, seqlen)),
diagonal=1).cast('bool')
_default_causal_mask[seqlen] = paddle.triu(paddle.ones((seqlen, seqlen)), diagonal=1).cast(
"bool"
)
return _default_causal_mask[seqlen]
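# For illustration, seqlen=3 produces
#     [[False, True,  True ],
#      [False, False, True ],
#      [False, False, False]]
# where True marks the future positions that the mask function suppresses.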
......@@ -58,8 +59,9 @@ class ScaledUpperTriangMaskedSoftmax(paddle.autograd.PyLayer):
def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
"""ScaledUpperTriangMaskedSoftmax bwd"""
softmax_results, scale_t = ctx.saved_tensor()
input_grads = scaled_upper_triang_masked_softmax_backward(output_grads, softmax_results,
scale_t[0])
input_grads = scaled_upper_triang_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
......@@ -140,7 +142,7 @@ class FusedScaleMaskSoftmax(paddle.nn.Layer):
attn_mask_type: str,
mask_func: Callable,
softmax_in_fp32: bool = True,
backend: str = 'transformer_engine',
backend: str = "transformer_engine",
) -> None:
super().__init__()
self.attn_mask_type = attn_mask_type
......@@ -162,16 +164,17 @@ class FusedScaleMaskSoftmax(paddle.nn.Layer):
self.input_is_bf16 = inp.dtype == paddle.bfloat16
self.input_in_16bit_float = self.input_is_fp16 or self.input_is_bf16
assert (scale is None or self.softmax_in_fp32), "softmax should be in fp32 when scaled"
assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled"
if self.backend == 'transformer_engine' and not self.is_kernel_available(*inp.shape):
if self.backend == "transformer_engine" and not self.is_kernel_available(*inp.shape):
warnings.warn(
"fused kernel is not available for this input shape, fall back to paddle backend")
self.backend = 'paddle'
"fused kernel is not available for this input shape, fall back to paddle backend"
)
self.backend = "paddle"
if self.backend == 'transformer_engine':
if self.backend == "transformer_engine":
return self._te_forward(inp, mask, scale)
if self.backend == 'paddle':
if self.backend == "paddle":
return self._pd_forward(inp, mask, scale)
raise AttributeError(f"Backend {self.backend} is not supported.")
......@@ -179,7 +182,8 @@ class FusedScaleMaskSoftmax(paddle.nn.Layer):
"""Check FusedScaleMaskSoftmax kernel availability based on size"""
attn_batches = b * h
if (self.scaled_masked_softmax_fusion # user want to fuse
if (
self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_16bit_float # input must be fp16 or bf16
and 16 < s_kv <= 4096 # s_kv must be 16 ~ 4096
and s_q % 4 == 0 # s_q must be a multiple of 4
......@@ -196,10 +200,9 @@ class FusedScaleMaskSoftmax(paddle.nn.Layer):
return True
return False
def _te_forward(self,
inp: paddle.Tensor,
mask: paddle.Tensor,
scale: Optional[float] = None) -> paddle.Tensor:
def _te_forward(
self, inp: paddle.Tensor, mask: paddle.Tensor, scale: Optional[float] = None
) -> paddle.Tensor:
"""Fused masked softmax kernel"""
b, h, s_q, s_kv = inp.size()
scale = 1.0 if scale is None else scale
......@@ -216,13 +219,12 @@ class FusedScaleMaskSoftmax(paddle.nn.Layer):
return ScaledMaskedSoftmax.apply(inp, mask, scale)
return ScaledSoftmax.apply(inp, scale)
def _pd_forward(self,
inp: paddle.Tensor,
mask: paddle.Tensor,
scale: Optional[float] = None) -> paddle.Tensor:
def _pd_forward(
self, inp: paddle.Tensor, mask: paddle.Tensor, scale: Optional[float] = None
) -> paddle.Tensor:
"""Call Paddle OP"""
if self.input_in_16bit_float and self.softmax_in_fp32:
inp = paddle.cast(inp, 'float32')
inp = paddle.cast(inp, "float32")
if scale is not None:
inp = inp * scale
......@@ -235,9 +237,9 @@ class FusedScaleMaskSoftmax(paddle.nn.Layer):
if self.input_in_16bit_float and self.softmax_in_fp32:
if self.input_is_fp16:
probs = paddle.cast(probs, 'float16')
probs = paddle.cast(probs, "float16")
else:
probs = paddle.cast(probs, 'bfloat16')
probs = paddle.cast(probs, "bfloat16")
return probs
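Putting the Paddle fallback together, a self-contained sketch of the same sequence of steps (scale and softmax in fp32, then cast back to the input dtype); the masking convention and fill value here are assumptions, since mask_func is supplied by the caller:

from typing import Optional

import paddle
import paddle.nn.functional as F

def masked_softmax_reference(inp: paddle.Tensor,
                             mask: Optional[paddle.Tensor] = None,
                             scale: float = 1.0) -> paddle.Tensor:
    """Scaled, optionally masked softmax computed in fp32, cast back to the input dtype."""
    orig_dtype = inp.dtype
    x = paddle.cast(inp, "float32") * scale
    if mask is not None:
        # Assumed convention: True entries are masked out with a large negative value.
        x = paddle.where(mask, paddle.full_like(x, -1e4), x)
    probs = F.softmax(x, axis=-1)
    return paddle.cast(probs, orig_dtype)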
......
......@@ -112,7 +112,8 @@ class TransformerLayer(paddle.nn.Layer):
"""
def __init__(self,
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
......@@ -130,14 +131,15 @@ class TransformerLayer(paddle.nn.Layer):
layer_type: str = "encoder",
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False,
activation: str = 'gelu',
activation: str = "gelu",
set_parallel_mode: bool = False,
sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None,
fuse_wgrad_accumulation: bool = False,
attention_dropout_rng_state_name: str = 'local_seed',
hidden_dropout_rng_state_name: str = 'global_seed',
backend: str = 'transformer_engine') -> None:
attention_dropout_rng_state_name: str = "local_seed",
hidden_dropout_rng_state_name: str = "global_seed",
backend: str = "transformer_engine",
) -> None:
super().__init__()
params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype
......@@ -146,19 +148,23 @@ class TransformerLayer(paddle.nn.Layer):
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.self_attn_mask_type = self_attn_mask_type
self.set_parallel_mode = set_parallel_mode
self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
enable_tp=set_parallel_mode)
self.tp_group, self.tp_size = get_tp_group_and_world_size(
tp_group, enable_tp=set_parallel_mode
)
self.tensor_parallel = self.tp_size > 1
self.sequence_parallel = self.tensor_parallel and sequence_parallel
self.hidden_dropout_rng_state_name = hidden_dropout_rng_state_name
# SP needs local seed for hidden dropout
if self.sequence_parallel and self.hidden_dropout_rng_state_name == 'global_seed':
warnings.warn("RNG state for hidden dropout needs to be different across TP ranks. "
"Forcing hidden_dropout_rng_state_name to 'local_seed'")
self.hidden_dropout_rng_state_name = 'local_seed'
if self.sequence_parallel and self.hidden_dropout_rng_state_name == "global_seed":
warnings.warn(
"RNG state for hidden dropout needs to be different across TP ranks. "
"Forcing hidden_dropout_rng_state_name to 'local_seed'"
)
self.hidden_dropout_rng_state_name = "local_seed"
assert (self_attn_mask_type
in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported"
assert (
self_attn_mask_type in AttnMaskTypes
), f"self_attn_mask_type {self_attn_mask_type} not supported"
assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"
attention_args = (
......@@ -176,7 +182,7 @@ class TransformerLayer(paddle.nn.Layer):
"zero_centered_gamma": zero_centered_gamma,
"set_parallel_mode": set_parallel_mode,
"sequence_parallel": self.sequence_parallel,
'max_sequence_length': max_sequence_length,
"max_sequence_length": max_sequence_length,
"tp_group": tp_group,
"num_gqa_groups": num_gqa_groups,
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
......@@ -295,10 +301,12 @@ class TransformerLayer(paddle.nn.Layer):
"""
if self.self_attn_mask_type != "causal" and attention_mask is not None:
assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor"
assert attention_mask.dtype == paddle.bool, "Attention mask must be a boolean tensor"
assert core_attention_bias_type in ['no_bias'], f"Only no_bias is supported currently, " \
assert core_attention_bias_type in ["no_bias"], (
"Only no_bias is supported currently, "
f"but receive core_attention_bias_type = {core_attention_bias_type}"
)
# Self attention.
self_attention_outputs = self.self_attention(
......@@ -340,8 +348,9 @@ class TransformerLayer(paddle.nn.Layer):
attention_output = inter_attention_outputs
residual = bda_output
with track_rng_state(enable=self.tensor_parallel,
name=self.hidden_dropout_rng_state_name):
with track_rng_state(
enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name
):
bda_output = self.fused_dropout_add2(attention_output, residual)
# MLP.
......