Unverified Commit 4daea528 authored by Tian Zheng, committed by GitHub

Add more TE operators for Paddle (#262)



* Add cast_transpose

Add gelu, gelu_fp8

Add cast_transpose_bgrad_dgelu

Add layernorm_fwd and layernorm_fwd_fp8

Add layernorm_bwd
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix missing header
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5c58beaa
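
For orientation, a minimal sketch of how the new operators are invoked from Python, mirroring the tests added in this commit (BF16 paths only; the FP8 variants additionally take an FP8TensorMeta and a tensor index, as the tests below show). Shapes and dtypes follow the tests; the constant gamma/beta/dz values here are purely illustrative.

import paddle
import transformer_engine  # pylint: disable=unused-import  (imported for side effects, as in the tests)
import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import transpose, te_gelu, layernorm_fwd, layernorm_bwd

x = paddle.rand(shape=(16, 32), dtype='bfloat16')
x_t = transpose(x, otype=tex.DType.kBFloat16)      # standalone transpose kernel
y = te_gelu(x, otype=tex.DType.kBFloat16)          # non-FP8 GELU forward
gamma = paddle.ones(shape=(32,), dtype='bfloat16')
beta = paddle.zeros(shape=(32,), dtype='bfloat16')
ln_out, mu, rsigma = layernorm_fwd(x, gamma, beta, 1e-3, tex.DType.kBFloat16)
dx, dgamma, dbeta = layernorm_bwd(paddle.ones_like(ln_out), x, mu, rsigma, gamma)
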
@@ -10,7 +10,20 @@ from utils import assert_allclose, create_fp8_meta
import transformer_engine # pylint: disable=unused-import
import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
from transformer_engine.paddle.cpp_extensions import cast_to_fp8, cast_from_fp8, gemm, fp8_gemm
from transformer_engine.paddle.cpp_extensions import (
cast_to_fp8,
cast_from_fp8,
gemm,
fp8_gemm,
transpose,
cast_transpose,
te_gelu,
gelu_fp8,
dgelu_cast_transpose_bgrad_fp8,
layernorm_fwd_fp8,
layernorm_fwd,
layernorm_bwd,
)
from transformer_engine.paddle.fp8 import is_fp8_available
paddle.seed(10)
@@ -36,6 +49,145 @@ def test_quantize_dequantize():
assert_allclose(a, b, rtol=5e-2, atol=5e-2)
class TestTranspose:
"""
Test transpose operators
"""
@staticmethod
def test_transpose_bf16():
"""
Test BF16 transpose
"""
a = paddle.rand(shape=(16, 32), dtype='bfloat16')
a_transposed = transpose(a, otype=tex.DType.kBFloat16)
assert_allclose(a_transposed, a.T)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_transpose_fp8(fp8_dtype):
"""
Test FP8 transpose
"""
min_val = -8
max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
fp8_meta = create_fp8_meta(num_fp8_tensors=1, amax_history_len=1)
a_fp8 = cast_to_fp8(a, fp8_meta, tex.FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype)
a_transposed = cast_from_fp8(a_fp8_transposed,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
assert_allclose(a_transposed, a.T)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_cast_transpose(fp8_dtype):
"""
Test cast_transpose
"""
min_val = -8
max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
fp8_meta = create_fp8_meta(num_fp8_tensors=1, amax_history_len=1)
a_fp8_casted, a_fp8_transposed = cast_transpose(a,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
otype=fp8_dtype)
a_transposed = cast_from_fp8(a_fp8_transposed,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
a_casted = cast_from_fp8(a_fp8_casted,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
assert_allclose(a_casted, a)
assert_allclose(a_transposed, a.T)
class TestActivation:
"""
Test activation operators
"""
@staticmethod
def test_gelu_bf16():
"""
Test BF16 GELU Forward
"""
a = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1
gelu_out = te_gelu(a, otype=tex.DType.kBFloat16)
gelu_ref = paddle.nn.GELU()(a)
assert_allclose(gelu_out, gelu_ref, rtol=1e-2)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_gelu_fp8(fp8_dtype):
"""
Test FP8 GELU Forward
"""
a = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
fp8_meta = create_fp8_meta(num_fp8_tensors=1, amax_history_len=1)
gelu_out_fp8 = gelu_fp8(a, fp8_meta, tex.FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
gelu_out = cast_from_fp8(gelu_out_fp8,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
gelu_ref = paddle.nn.GELU()(a)
assert_allclose(gelu_out, gelu_ref, rtol=0.1, atol=0.01)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_gelu_bwd_fp8(fp8_dtype):
"""
Test FP8 GELU Backward
"""
# y = GELU(x), calculate ref
x = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
x.stop_gradient = False
y = paddle.nn.GELU()(x)
y_grad = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
paddle.autograd.backward([y], [y_grad], True)
# calculate fp8
fp8_meta = create_fp8_meta(num_fp8_tensors=1, amax_history_len=1)
x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8(
y_grad, x, fp8_meta, tex.FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
x_grad = cast_from_fp8(x_grad_fp8,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
x_grad_t = cast_from_fp8(x_grad_t_fp8,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01)
assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01)
assert_allclose(dbias, x.grad.sum(axis=0), rtol=0.1, atol=0.01)
class TestGemm:
"""
Tests for gemm(cuBLASLt) operator
@@ -114,3 +266,107 @@ class TestGemm:
tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype, out_dtype, workspace)
assert_allclose(actual_out, ref_out)
class TestLayerNorm:
"""
Test layernorm operators
"""
@staticmethod
def calc_fwd_ref(x, eps, gamma, beta):
"""
Calculate reference using paddle layer_norm op
"""
y = paddle.nn.functional.layer_norm(x=x,
normalized_shape=x.shape[1:],
weight=gamma,
bias=beta,
epsilon=eps)
mean = paddle.mean(x, axis=-1)
var = paddle.var(x, axis=-1)
inv_var = paddle.sqrt(1. / var)
return y, mean, inv_var
@staticmethod
def calc_bwd_ref(x, eps, gamma, beta, dy):
"""
Calculate reference using paddle layer_norm op
"""
x.stop_gradient = False
gamma.stop_gradient = False
beta.stop_gradient = False
y = paddle.nn.functional.layer_norm(x=x,
normalized_shape=x.shape[1:],
weight=gamma,
bias=beta,
epsilon=eps)
paddle.autograd.backward([y], [dy], True)
return x.grad, gamma.grad, beta.grad
def test_layernorm_fwd(self):
"""
Test BF16 LayerNorm Forward
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='bfloat16')
gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
beta = paddle.uniform(shape=(H,), dtype='bfloat16')
y, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16)
y_ref, mu_ref, rsigma_ref = self.calc_fwd_ref(x, eps, gamma, beta)
assert_allclose(y, y_ref, rtol=1e-5, atol=1e-5)
assert_allclose(mu, mu_ref, rtol=1e-3, atol=1e-3)
assert_allclose(rsigma, rsigma_ref, rtol=5e-2, atol=5e-2)
@staticmethod
def test_layernorm_fwd_fp8():
"""
Test FP8 LayerNorm Forward
"""
fp8_dtype = tex.DType.kFloat8E4M3
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='float32')
gamma = paddle.uniform(shape=(H,), dtype='float32')
beta = paddle.uniform(shape=(H,), dtype='float32')
fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
fp8_meta = create_fp8_meta(num_fp8_tensors=1, amax_history_len=1)
y_ref, mu_ref, rsigma_ref = layernorm_fwd(x, gamma, beta, eps, tex.DType.kFloat32)
y_fp8, mu, rsigma = layernorm_fwd_fp8(x, gamma, beta, eps, fp8_meta, fp8_tensor, fp8_dtype)
y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32)
assert_allclose(y, y_ref, rtol=0.1, atol=0.01)
assert_allclose(mu, mu_ref)
assert_allclose(rsigma, rsigma_ref)
def test_layernorm_bwd(self):
"""
Test BF16 LayerNorm Backward
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='bfloat16')
dy = paddle.uniform(shape=(N, H), dtype='bfloat16')
gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
beta = paddle.uniform(shape=(H,), dtype='bfloat16')
dx_ref, dgamma_ref, dbeta_ref = self.calc_bwd_ref(x, eps, gamma, beta, dy)
_, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16)
dx, dgamma, dbeta = layernorm_bwd(dy, x, mu, rsigma, gamma)
assert_allclose(dx, dx_ref, rtol=1e-5, atol=1e-5)
assert_allclose(dgamma, dgamma_ref, rtol=1e-5, atol=1e-5)
assert_allclose(dbeta, dbeta_ref, rtol=1e-5, atol=1e-5)
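
For reference, the per-row statistics returned by layernorm_fwd and approximated by calc_fwd_ref above. Note that the reference computes rsigma as 1/sqrt(Var(x)) without the eps term, which the loose 5e-2 tolerance on rsigma absorbs:

\mu_i = \frac{1}{H}\sum_{j=1}^{H} x_{ij}, \qquad
\mathrm{rsigma}_i \approx \frac{1}{\sqrt{\operatorname{Var}(x_{i,:})}}, \qquad
y_{ij} = \gamma_j\,(x_{ij} - \mu_i)\,\mathrm{rsigma}_i + \beta_j
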
@@ -208,3 +208,133 @@ def cast_from_fp8(
int(itype),
int(otype),
)
def transpose(
inp: paddle.Tensor,
otype: tex.DType,
) -> paddle.Tensor:
"""Transpose input"""
return tex.te_transpose(
inp,
int(otype),
)
def cast_transpose(
inp: paddle.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> Union[Tuple[paddle.Tensor, paddle.Tensor], None]:
"""Cast + Transpose with FP8 output"""
cast_out, transpose_out, _, _ = tex.te_cast_transpose(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
int(fp8_tensor),
int(otype),
)
return cast_out, transpose_out
def te_gelu(
inp: paddle.Tensor,
otype: tex.DType,
) -> paddle.Tensor:
"""Non FP8 GELU"""
return tex.te_gelu(
inp,
int(otype),
)
def gelu_fp8(
inp: paddle.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> paddle.Tensor:
"""GELU + FP8 cast"""
out, _, _ = tex.te_gelu_fp8(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
int(fp8_tensor),
int(otype),
)
return out
def dgelu_cast_transpose_bgrad_fp8(
grad_output: paddle.Tensor,
gelu_input: paddle.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""
Fused dgelu + cast / transpose / reduce the result of
the GELU backward along the first dimension
"""
cast_dgelu, transpose_dgelu, dbias, _, _ = tex.te_cast_transpose_bgrad_dgelu(
grad_output,
gelu_input,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
int(fp8_tensor),
int(otype),
)
return cast_dgelu, transpose_dgelu, dbias
def layernorm_fwd_fp8(
inp: paddle.Tensor,
weight: paddle.Tensor,
bias: paddle.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""LayerNorm with FP8 output"""
out, mu, rsigma, _, _ = tex.te_layernorm_fwd_fp8(inp, weight, bias, fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv, eps,
int(fp8_tensor), int(otype), sm_margin,
zero_centered_gamma)
return out, mu, rsigma
def layernorm_fwd(
inp: paddle.Tensor,
weight: paddle.Tensor,
bias: paddle.Tensor,
eps: float,
otype: tex.DType,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 LayerNorm forward"""
return tex.te_layernorm_fwd(inp, weight, bias, eps, int(otype), sm_margin, zero_centered_gamma)
def layernorm_bwd(
dz: paddle.Tensor,
x: paddle.Tensor,
mu: paddle.Tensor,
rsigma: paddle.Tensor,
gamma: paddle.Tensor,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 LayerNorm backward"""
return tex.te_layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma)
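
The FP8 wrappers above all follow the same pattern: pass the FP8TensorMeta's scale / amax_history / scale_inv plus a tensor index, and cast the FP8 result back for inspection. A condensed sketch mirroring test_gelu_fp8 earlier in this commit (create_fp8_meta is the test suite's helper from utils, not a public API):

import paddle
import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import gelu_fp8, cast_from_fp8
from utils import create_fp8_meta  # test-suite helper, assumed importable as in the tests

fp8_dtype = tex.DType.kFloat8E4M3
x = paddle.rand(shape=(16, 32), dtype='float32')
fp8_meta = create_fp8_meta(num_fp8_tensors=1, amax_history_len=1)
y_fp8 = gelu_fp8(x, fp8_meta, tex.FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
y = cast_from_fp8(y_fp8, fp8_meta, tex.FP8FwdTensors.GEMM1_INPUT,
                  itype=fp8_dtype, otype=tex.DType.kFloat32)
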
@@ -13,6 +13,10 @@ TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, c
return TensorWrapper(data_ptr, shape, type);
}
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type) {
return TensorWrapper(data_ptr, shape, type);
}
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
void *amax_ptr, void *scale_ptr, void *scale_inv_ptr) {
return TensorWrapper(data_ptr, shape, type, reinterpret_cast<float *>(amax_ptr),
@@ -20,10 +24,33 @@ TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, c
reinterpret_cast<float *>(scale_inv_ptr));
}
TensorWrapper MakeNvteTensor(paddle::Tensor &tensor) { // NOLINT
return MakeNvteTensor(tensor.data(), GetShapeArray(tensor), Paddle2NvteDType(tensor.dtype()));
}
TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor) {
return MakeNvteTensor(const_cast<void *>(tensor.data()), GetShapeArray(tensor),
Paddle2NvteDType(tensor.dtype()));
}
paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place,
bool init_to_zeros) {
auto size = shape.ndim;
if (size == 2 && init_to_zeros) {
return paddle::zeros(
{static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
Nvte2PaddleDType(type), place);
} else if (size == 2) {
return paddle::empty(
{static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
Nvte2PaddleDType(type), place);
} else if (size == 1 && init_to_zeros) {
return paddle::zeros({static_cast<int64_t>(shape.data[0])}, Nvte2PaddleDType(type), place);
} else if (size == 1) {
return paddle::empty({static_cast<int64_t>(shape.data[0])}, Nvte2PaddleDType(type), place);
}
NVTE_CHECK(false, "Should never reach here! func: AllocateSpace");
}
} // namespace paddle_ext
} // namespace transformer_engine
@@ -6,10 +6,13 @@
#pragma once
#include <cublasLt.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <vector>
#include "paddle/extension.h"
@@ -88,6 +91,9 @@ inline std::vector<size_t> GetShapeArray(const paddle::Tensor &x) {
return shapes;
}
paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place,
bool init_to_zeros = false);
// DType Utils
inline paddle::DataType Nvte2PaddleDType(DType t) {
switch (t) {
@@ -136,10 +142,45 @@ inline DType Int2NvteDType(int64_t dtype) {
}
}
// CUDA Utils
class cudaDevicePropertiesManager {
public:
static cudaDevicePropertiesManager &Instance() {
static thread_local cudaDevicePropertiesManager instance;
return instance;
}
int GetMultiProcessorCount() {
if (!prop_queried_) {
int device_id;
NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
cudaGetDeviceProperties(&prop_, device_id);
prop_queried_ = true;
}
return prop_.multiProcessorCount;
}
int GetMajor() {
if (!prop_queried_) {
int device_id;
NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
cudaGetDeviceProperties(&prop_, device_id);
prop_queried_ = true;
}
return prop_.major;
}
private:
bool prop_queried_ = false;
cudaDeviceProp prop_;
};
// NVTE Tensor Utils
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type);
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type);
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
void *amax_ptr, void *scale_ptr, void *scale_inv_ptr);
TensorWrapper MakeNvteTensor(paddle::Tensor &tensor); // NOLINT
TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor);
} // namespace paddle_ext
@@ -42,6 +42,53 @@ std::vector<paddle::Tensor> cast_from_fp8(const paddle::Tensor &input,
return {output};
}
std::vector<paddle::Tensor> te_transpose(const paddle::Tensor &input, int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
auto output = paddle::empty({input.shape()[1], input.shape()[0]}, input.dtype(), input.place());
auto input_cu = MakeNvteTensor(const_cast<void *>(input.data()), {M, N}, Int2NvteDType(otype));
auto output_cu = MakeNvteTensor(output.data(), {N, M}, Int2NvteDType(otype));
nvte_transpose(input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> te_cast_transpose(const paddle::Tensor &input,
const paddle::Tensor &scale,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
auto input_cast =
paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_transpose = paddle::empty({input.shape()[1], input.shape()[0]},
Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_cu = MakeNvteTensor(input);
void *amax_data = GetDataPtr<float>(amax, index);
void *scale_data = const_cast<void *>(GetDataPtr<float>(scale, index));
void *scale_inv_data = GetDataPtr<float>(scale_inv, index);
auto output_cast_cu = MakeNvteTensor(input_cast.data(), {M, N}, Int2NvteDType(otype), amax_data,
scale_data, scale_inv_data);
auto output_transpose_cu = MakeNvteTensor(input_transpose.data(), {N, M}, Int2NvteDType(otype),
amax_data, scale_data, scale_inv_data);
nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
input.stream());
return {input_cast, input_transpose};
}
void te_gemm(const paddle::Tensor &A, const paddle::optional<paddle::Tensor> &A_scale_inverse,
const paddle::Tensor &B, const paddle::optional<paddle::Tensor> &B_scale_inverse,
const paddle::optional<paddle::Tensor> &bias, paddle::Tensor &D, // NOLINT
@@ -77,6 +124,219 @@ void te_gemm(const paddle::Tensor &A, const paddle::optional<paddle::Tensor> &A_
math_sm_count, A.stream());
}
std::vector<paddle::Tensor> te_gelu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
auto output = paddle::empty_like(input, Nvte2PaddleDType(DType::kByte), input.place());
auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor(
output.data(), GetShapeArray(input), Int2NvteDType(otype), GetDataPtr<float>(amax, index),
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
nvte_gelu(input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> te_gelu(const paddle::Tensor &input, int64_t otype) {
auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(input), Int2NvteDType(otype));
nvte_gelu(input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> te_cast_transpose_bgrad_dgelu(const paddle::Tensor &grad_output,
const paddle::Tensor &gelu_input,
const paddle::Tensor &scale,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
auto shape = GetShapeArray(grad_output);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
// DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
auto grad_bias =
paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place());
auto dgelu =
paddle::empty_like(grad_output, Nvte2PaddleDType(DType::kByte), grad_output.place());
auto dgelu_transpose = paddle::empty({grad_output.shape()[1], grad_output.shape()[0]},
Nvte2PaddleDType(DType::kByte), grad_output.place());
void *amax_data = GetDataPtr<float>(amax, index);
void *scale_data = const_cast<void *>(GetDataPtr<float>(scale, index));
void *scale_inv_data = GetDataPtr<float>(scale_inv, index);
TensorWrapper workspace;
auto gelu_input_cu = MakeNvteTensor(gelu_input);
auto input_cu = MakeNvteTensor(grad_output);
auto cast_output_cu = MakeNvteTensor(dgelu.data(), {M, N}, Int2NvteDType(otype), amax_data,
scale_data, scale_inv_data);
auto transposed_output_cu = MakeNvteTensor(dgelu_transpose.data(), {N, M}, Int2NvteDType(otype),
amax_data, scale_data, scale_inv_data);
auto dbias_cu = MakeNvteTensor(grad_bias);
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(), workspace.data(),
grad_output.stream());
// Fill workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(), workspace.data(),
grad_output.stream());
return {dgelu, dgelu_transpose, grad_bias};
}
std::vector<paddle::Tensor> te_layernorm_fwd_fp8(const paddle::Tensor &input,
const paddle::Tensor &weight,
const paddle::Tensor &bias,
const paddle::Tensor &scale,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
float eps, int64_t index, int64_t otype,
int64_t sm_margin, bool zero_centered_gamma) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");
size_t N = shape[0];
size_t H = shape[1];
auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
auto mu = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto rsigma =
paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
auto gamma_cu = MakeNvteTensor(weight);
auto beta_cu = MakeNvteTensor(bias);
auto z_cu = MakeNvteTensor(
ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr<float>(amax, index),
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
auto mu_cu = MakeNvteTensor(mu);
auto rsigma_cu = MakeNvteTensor(rsigma);
TensorWrapper workspace, barrier;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(),
rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place());
auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true);
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(),
rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
std::vector<paddle::Tensor> te_layernorm_fwd(const paddle::Tensor &input,
const paddle::Tensor &weight,
const paddle::Tensor &bias, float eps, int64_t otype,
int64_t sm_margin, bool zero_centered_gamma) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");
size_t N = shape[0];
size_t H = shape[1];
auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
auto mu = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto rsigma =
paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
auto gamma_cu = MakeNvteTensor(weight);
auto beta_cu = MakeNvteTensor(bias);
auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype));
auto mu_cu = MakeNvteTensor(mu);
auto rsigma_cu = MakeNvteTensor(rsigma);
TensorWrapper workspace, barrier;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(),
rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place());
auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true);
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(),
rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
std::vector<paddle::Tensor> te_layernorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x,
const paddle::Tensor &mu, const paddle::Tensor &rsigma,
const paddle::Tensor &gamma, int64_t sm_margin,
bool zero_centered_gamma) {
auto dx = paddle::empty_like(x, x.dtype(), x.place());
auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place());
auto dbeta = paddle::empty_like(gamma, gamma.dtype(), gamma.place());
TensorWrapper workspace, barrier, dgamma_part, dbeta_part;
auto dz_cu = MakeNvteTensor(dz);
auto x_cu = MakeNvteTensor(x);
auto mu_cu = MakeNvteTensor(mu);
auto rsigma_cu = MakeNvteTensor(rsigma);
auto gamma_cu = MakeNvteTensor(gamma);
auto dx_cu = MakeNvteTensor(dx);
auto dgamma_cu = MakeNvteTensor(dgamma);
auto dbeta_cu = MakeNvteTensor(dbeta);
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates tensors with the required config.
const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(),
dz.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
// Alloc space for Tensors.
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place());
auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true);
auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place());
auto dbeta_part_data = AllocateSpace(dbeta_part.shape(), dbeta_part.dtype(), x.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype());
dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype());
dbeta_part = MakeNvteTensor(dbeta_part_data.data(), dbeta_part.shape(), dbeta_part.dtype());
// Actual call to bwd kernel.
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(),
dz.stream(), num_sm - sm_margin, workspace.data(), barrier.data());
return {dx, dgamma, dbeta};
}
} // namespace paddle_ext
} // namespace transformer_engine
@@ -109,3 +369,56 @@ PD_BUILD_OP(cast_from_fp8)
.Outputs({"Output"})
.Attrs({"index: int64_t", "itype: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_from_fp8));
PD_BUILD_OP(te_transpose)
.Inputs({"Input"})
.Outputs({"Output"})
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_transpose));
PD_BUILD_OP(te_cast_transpose)
.Inputs({"Input", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"CastedOutput", "TransposedOutput", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose));
PD_BUILD_OP(te_gelu_fp8)
.Inputs({"Input", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"Output", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu_fp8));
PD_BUILD_OP(te_gelu)
.Inputs({"Input"})
.Outputs({"Output"})
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu));
PD_BUILD_OP(te_cast_transpose_bgrad_dgelu)
.Inputs({"GradOutput", "GeluInput", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"CastedDgelu", "TransposedDgelu", "Dbias", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose_bgrad_dgelu));
PD_BUILD_OP(te_layernorm_fwd_fp8)
.Inputs({"Input", "Weight", "Bias", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"Output", "Mu", "Rsigma", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t",
"zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_fwd_fp8));
PD_BUILD_OP(te_layernorm_fwd)
.Inputs({"Input", "Weight", "Bias"})
.Outputs({"Output", "Mu", "Rsigma"})
.Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_fwd));
PD_BUILD_OP(te_layernorm_bwd)
.Inputs({"Dz", "X", "Mu", "Rsigma", "Gamma"})
.Outputs({"Dx", "Dgamma", "Dbeta"})
.Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_bwd));
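
The _Amax/Amax and _ScaleInv/ScaleInv in-place mappings above are why the Python wrappers earlier in this commit unpack and discard trailing outputs; for example, cast_transpose (copied from the cpp_extensions diff) does:

cast_out, transpose_out, _, _ = tex.te_cast_transpose(
    inp,
    fp8_meta_tensor.scale,
    fp8_meta_tensor.amax_history,
    fp8_meta_tensor.scale_inv,
    int(fp8_tensor),
    int(otype),
)
# The two ignored outputs are the Amax and ScaleInv tensors declared in
# SetInplaceMap, i.e. the amax_history and scale_inv buffers passed in are
# updated in place by the op.
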