Unverified commit c67bb2fc authored by Przemyslaw Tredak, committed by GitHub

Adding other activation types to LayerNormMLP (#265)



* Added ReLU and GLU variants to common

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* PyTorch changes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* PyTorch C++ lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix storage errors

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Compute bgrad

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix numerical tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX export tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments

Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent df6f347f
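A minimal usage sketch of the new activation argument (not part of this commit's diff; shapes and dtype here are illustrative, mirroring the tests below):

import torch
import transformer_engine.pytorch as te

# New in this PR: LayerNormMLP accepts an activation from
# {"gelu", "relu", "reglu", "geglu", "swiglu"}.
ln_mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation="swiglu",
).to(dtype=torch.bfloat16).cuda()

x = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")
y = ln_mlp(x)  # output has the same shape as x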
@@ -57,6 +57,7 @@ batch_sizes = [1, 2]
 all_boolean = [True, False]
+all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]


 def get_causal_attn_mask(sq: int) -> torch.Tensor:
     return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
@@ -334,13 +335,37 @@ class TorchMHA(nn.Module):
     def forward(self, x, attn_mask=None):
         return self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False)

+_supported_act = {'geglu'  : nn.GELU(approximate="tanh"),
+                  'gelu'   : nn.GELU(approximate="tanh"),
+                  'reglu'  : nn.ReLU(),
+                  'relu'   : nn.ReLU(),
+                  'swiglu' : nn.SiLU()}
+
+
+class TorchGLU(nn.Module):
+    def __init__(self, activation: str):
+        super().__init__()
+        self.act = _supported_act[activation]
+
+    def forward(self, x):
+        shape = x.size(-1)
+        a = x[..., :shape // 2]
+        b = x[..., (shape // 2):]
+        a = self.act(a)
+        return a * b
+
+
 class TorchLayerNormMLP(nn.Module):
-    def __init__(self, hidden_size: int, ffn_hidden_size: int, eps: float = 1e-5):
+    def __init__(self, hidden_size: int, ffn_hidden_size: int,
+                 eps: float = 1e-5, activation = 'gelu'):
         super().__init__()
         self.ln = nn.LayerNorm(hidden_size, eps=eps)
-        self.fc1 = nn.Linear(hidden_size, ffn_hidden_size)
-        self.gelu = nn.GELU(approximate="tanh")
+        if 'glu' in activation:
+            fc1_output_features = 2 * ffn_hidden_size
+            self.gelu = TorchGLU(activation)
+        else:
+            fc1_output_features = ffn_hidden_size
+            self.gelu = _supported_act[activation]
+        self.fc1 = nn.Linear(hidden_size, fc1_output_features)
         self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)

     def forward(self, x):
@@ -856,13 +881,15 @@ def test_layernorm_linear_accuracy(dtype, bs, model):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_layernorm_mlp_accuracy(dtype, bs, model):
+@pytest.mark.parametrize("activation", all_activations)
+def test_layernorm_mlp_accuracy(dtype, bs, model, activation):
     config = model_configs[model]

     te_ln_mlp = (
         LayerNormMLP(
             config.hidden_size,
             4 * config.hidden_size,
+            activation=activation,
         )
         .to(dtype=dtype)
         .cuda()
@@ -873,6 +900,7 @@ def test_layernorm_mlp_accuracy(dtype, bs, model):
         TorchLayerNormMLP(
             config.hidden_size,
             4 * config.hidden_size,
+            activation=activation,
         )
         .to(dtype=dtype)
         .cuda()
...
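For reference, the GLU-variant semantics tested above: fc1 produces 2x ffn_hidden_size features, and the activation is applied to the first half of the last dimension to gate the second half. A quick sketch (assumes the TorchGLU class from the hunk above; sizes are arbitrary):

import torch
import torch.nn as nn

x = torch.randn(4, 8)
a, b = x[..., :4], x[..., 4:]
ref = nn.SiLU()(a) * b  # what TorchGLU("swiglu") computes on x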
@@ -20,7 +20,6 @@ To run many repetitive tests use pytest-loop:
 For reproducability use: torch.manual_seed(0)
 """
-
 import os
 import tempfile
 import pytest
@@ -33,7 +32,7 @@ from typing import Optional, Union, Tuple, List
 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 import transformer_engine_extensions as tex
-from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, fp8_gelu, cast_to_fp8, cast_from_fp8
+from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8
 from transformer_engine.pytorch.module.base import get_workspace
 import transformer_engine.pytorch.cpp_extensions as texcpp
 import transformer_engine.pytorch.softmax as softmax_defs
@@ -403,7 +402,7 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
         self.fake_bf16_io = fake_bf16_io

     def forward(self, inp):
-        ret = fp8_gelu(
+        ret = gelu(
             inp,
             self.meta,
             self.fp8_tensor,
...
@@ -94,6 +94,7 @@ batch_sizes = [1, 2]
 all_boolean = [True, False]
+all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]


 def _disable_wgrads(block):
     for p in block.parameters():
@@ -368,7 +369,8 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad, skip_dgrad):
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 @pytest.mark.parametrize("skip_dgrad", all_boolean)
-def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad):
+@pytest.mark.parametrize("activation", all_activations)
+def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -386,6 +388,7 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_cen
             init_method=init_method,
             output_layer_init_method=output_layer_init_method,
             zero_centered_gamma=zero_centered_gamma,
+            activation=activation,
         )
         .to(dtype=dtype)
         .cuda()
@@ -400,7 +403,8 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_cen
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 @pytest.mark.parametrize("bias", all_boolean)
-def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, bias):
+@pytest.mark.parametrize("activation", all_activations)
+def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, bias, activation):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -425,6 +429,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamm
             output_layernorm=False,
             zero_centered_gamma=zero_centered_gamma,
             bias=bias,
+            activation=activation,
         )
         .to(dtype=dtype)
         .cuda()
...
@@ -12,6 +12,8 @@ list(APPEND transformer_engine_SOURCES
      transpose/transpose_fusion.cu
      transpose/multi_cast_transpose.cu
      activation/gelu.cu
+     activation/relu.cu
+     activation/swiglu.cu
      fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu
      fused_attn/fused_attn_fp8.cu
      fused_attn/fused_attn.cpp
...
@@ -16,30 +16,46 @@
 namespace transformer_engine {

-namespace detail {
-
-struct GELUParam {};
-
-__device__ inline fp32 gelu(fp32 value, const GELUParam &) {
-  return value * (0.5F + 0.5F * tanhf(value * (0.79788456F + 0.03567741F * value * value)));
-}
-
+void gelu(const Tensor &input,
+          Tensor *output,
+          cudaStream_t stream) {
+  CheckInputTensor(input, "gelu_input");
+  CheckOutputTensor(*output, "gelu_output");
+  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
+  const size_t tot_elts = product(input.data.shape);
+
+  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
+    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
+      constexpr int nvec = 32 / sizeof(IType);
+      VectorizedUnaryKernelLauncher<nvec, Empty, gelu<fp32, fp32> >(
+        reinterpret_cast<const IType*>(input.data.dptr),
+        reinterpret_cast<OType*>(output->data.dptr),
+        reinterpret_cast<const fp32*>(output->scale.dptr),
+        reinterpret_cast<fp32*>(output->amax.dptr),
+        tot_elts,
+        Empty(),
+        stream);
+    );  // NOLINT(*)
+  );  // NOLINT(*)
 }

-void gelu_cast(const Tensor &input,
-               Tensor *output,
-               cudaStream_t stream) {
-  CheckInputTensor(input, "gelu_input");
-  CheckOutputTensor(*output, "gelu_output");
-  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
-  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
+void dgelu(const Tensor &grad,
+           const Tensor &input,
+           Tensor *output,
+           cudaStream_t stream) {
+  CheckInputTensor(input, "dgelu_input");
+  CheckInputTensor(grad, "dgelu_input_grad");
+  CheckOutputTensor(*output, "dgelu_output");
   NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
-  const size_t tot_elts = input.data.shape[1] * input.data.shape[0];
+  NVTE_CHECK(input.data.dtype == grad.data.dtype,
+             "Input and incoming gradient types must match.");
+  const size_t tot_elts = product(input.data.shape);
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
     TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
       constexpr int nvec = 32 / sizeof(IType);
-      VectorizedUnaryKernelLauncher<nvec, detail::GELUParam, detail::gelu>(
+      VectorizedUnaryGradKernelLauncher<nvec, Empty, dgelu<fp32, fp32>>(
+        reinterpret_cast<const IType*>(grad.data.dptr),
         reinterpret_cast<const IType*>(input.data.dptr),
         reinterpret_cast<OType*>(output->data.dptr),
         reinterpret_cast<const fp32*>(output->scale.dptr),
@@ -51,7 +67,7 @@ void gelu_cast(const Tensor &input,
   );  // NOLINT(*)
 }

-void geglu_cast(const Tensor &input,
+void geglu(const Tensor &input,
            Tensor *output,
            cudaStream_t stream) {
   CheckInputTensor(input, "geglu_input");
@@ -61,18 +77,19 @@ void geglu_cast(const Tensor &input,
   NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
              "Input shape[0] must be equal to output shape[0].");
   NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
-             "Input shape[1] must be twice than output shape[1].");
+             "Input shape[1] must be 2x larger than output shape[1].");

   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
     TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
       constexpr int nvec = 32 / sizeof(IType);
-      GatedActivationKernelLauncher<nvec, fp32, gelu<fp32, fp32>>(
+      GatedActivationKernelLauncher<nvec, fp32, Empty, gelu<fp32, fp32>>(
         reinterpret_cast<const IType*>(input.data.dptr),
         reinterpret_cast<OType*>(output->data.dptr),
         reinterpret_cast<const fp32*>(output->scale.dptr),
         reinterpret_cast<fp32*>(output->amax.dptr),
         output->data.shape[0],
         output->data.shape[1],
+        {},
         stream);
     );  // NOLINT(*)
   );  // NOLINT(*)
@@ -91,19 +108,20 @@ void dgeglu(const Tensor &grad,
   NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
              "Output shape[0] must be equal to grad shape[0].");
   NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
-             "Output shape[1] must be twice than grad shape[1].");
+             "Output shape[1] must be 2x larger than grad shape[1].");
   NVTE_CHECK(input.data.shape == output->data.shape,
              "Input and output shapes must match.");

   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
     TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
       constexpr int nvec = 32 / sizeof(IType);
-      DGatedActivationKernelLauncher<nvec, fp32, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
+      DGatedActivationKernelLauncher<nvec, fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(grad.data.dptr),
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        grad.data.shape[0],
        grad.data.shape[1],
+       {},
        stream);
     );  // NOLINT(*)
   );  // NOLINT(*)
@@ -116,7 +134,19 @@ void nvte_gelu(const NVTETensor input,
                cudaStream_t stream) {
   NVTE_API_CALL(nvte_gelu);
   using namespace transformer_engine;
-  gelu_cast(*reinterpret_cast<const Tensor*>(input),
+  gelu(*reinterpret_cast<const Tensor*>(input),
+       reinterpret_cast<Tensor*>(output),
+       stream);
+}
+
+void nvte_dgelu(const NVTETensor grad,
+                const NVTETensor input,
+                NVTETensor output,
+                cudaStream_t stream) {
+  NVTE_API_CALL(nvte_dgelu);
+  using namespace transformer_engine;
+  dgelu(*reinterpret_cast<const Tensor*>(grad),
+        *reinterpret_cast<const Tensor*>(input),
         reinterpret_cast<Tensor*>(output),
         stream);
 }
@@ -126,7 +156,7 @@ void nvte_geglu(const NVTETensor input,
                 cudaStream_t stream) {
   NVTE_API_CALL(nvte_geglu);
   using namespace transformer_engine;
-  geglu_cast(*reinterpret_cast<const Tensor*>(input),
+  geglu(*reinterpret_cast<const Tensor*>(input),
         reinterpret_cast<Tensor*>(output),
         stream);
 }
...
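As a quick sanity check (not part of the PR), the closed-form dgelu used above matches a finite-difference derivative of the tanh-approximated GELU; note that 0.79788456 * 0.044715 = 0.03567741 and 0.79788456 * 3 * 0.044715 = 0.1070322243, which ties the forward and backward constants together:

import math

def gelu(x):   # tanh approximation, as in gelu.cu / util/math.h
    return x * (0.5 + 0.5 * math.tanh(x * (0.79788456 + 0.03567741 * x * x)))

def dgelu(x):  # closed-form derivative from util/math.h
    t = math.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x))
    return 0.5 * x * (1.0 - t * t) * (0.79788456 + 0.1070322243 * x * x) + 0.5 * (1.0 + t)

for x in (-2.0, -0.5, 0.0, 1.0, 3.0):
    fd = (gelu(x + 1e-5) - gelu(x - 1e-5)) / 2e-5  # central difference
    assert abs(fd - dgelu(x)) < 1e-4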
/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/activation.h>
#include <cuda_runtime.h>
#include "../util/vectorized_pointwise.h"
#include "../util/math.h"
#include "../common.h"

namespace transformer_engine {

void relu(const Tensor &input,
          Tensor *output,
          cudaStream_t stream) {
  CheckInputTensor(input, "relu_input");
  CheckOutputTensor(*output, "relu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  const size_t tot_elts = product(input.data.shape);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      constexpr int nvec = 32 / sizeof(IType);
      VectorizedUnaryKernelLauncher<nvec, Empty, relu<fp32, fp32>>(
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        tot_elts,
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

void drelu(const Tensor &grad,
           const Tensor &input,
           Tensor *output,
           cudaStream_t stream) {
  CheckInputTensor(input, "drelu_input");
  CheckInputTensor(grad, "drelu_input_grad");
  CheckOutputTensor(*output, "drelu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  NVTE_CHECK(input.data.dtype == grad.data.dtype,
             "Input and incoming gradient types must match.");
  const size_t tot_elts = product(input.data.shape);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      constexpr int nvec = 32 / sizeof(IType);
      VectorizedUnaryGradKernelLauncher<nvec, Empty, drelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(grad.data.dptr),
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        tot_elts,
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

void reglu(const Tensor &input,
           Tensor *output,
           cudaStream_t stream) {
  CheckInputTensor(input, "reglu_input");
  CheckOutputTensor(*output, "reglu_output");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
             "Input shape[0] must be equal to output shape[0].");
  NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
             "Input shape[1] must be 2x larger than output shape[1].");

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      constexpr int nvec = 32 / sizeof(IType);
      GatedActivationKernelLauncher<nvec, fp32, Empty, relu<fp32, fp32>>(
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        output->data.shape[0],
        output->data.shape[1],
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

void dreglu(const Tensor &grad,
            const Tensor &input,
            Tensor *output,
            cudaStream_t stream) {
  CheckInputTensor(grad, "dreglu_grad");
  CheckInputTensor(input, "dreglu_input");
  CheckOutputTensor(*output, "dreglu_output");
  NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
             "Output shape[0] must be equal to grad shape[0].");
  NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
             "Output shape[1] must be 2x larger than grad shape[1].");
  NVTE_CHECK(input.data.shape == output->data.shape,
             "Input and output shapes must match.");

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      constexpr int nvec = 32 / sizeof(IType);
      DGatedActivationKernelLauncher<nvec, fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(grad.data.dptr),
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        grad.data.shape[0],
        grad.data.shape[1],
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

}  // namespace transformer_engine

void nvte_relu(const NVTETensor input,
               NVTETensor output,
               cudaStream_t stream) {
  NVTE_API_CALL(nvte_relu);
  using namespace transformer_engine;
  relu(*reinterpret_cast<const Tensor*>(input),
       reinterpret_cast<Tensor*>(output),
       stream);
}

void nvte_drelu(const NVTETensor grad,
                const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_drelu);
  using namespace transformer_engine;
  drelu(*reinterpret_cast<const Tensor*>(grad),
        *reinterpret_cast<const Tensor*>(input),
        reinterpret_cast<Tensor*>(output),
        stream);
}

void nvte_reglu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_reglu);
  using namespace transformer_engine;
  reglu(*reinterpret_cast<const Tensor*>(input),
        reinterpret_cast<Tensor*>(output),
        stream);
}

void nvte_dreglu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dreglu);
  using namespace transformer_engine;
  dreglu(*reinterpret_cast<const Tensor*>(grad),
         *reinterpret_cast<const Tensor*>(input),
         reinterpret_cast<Tensor*>(output),
         stream);
}
...

/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/activation.h>
#include <cuda_runtime.h>
#include "../util/vectorized_pointwise.h"
#include "../util/math.h"
#include "../common.h"

namespace transformer_engine {

void swiglu(const Tensor &input,
            Tensor *output,
            cudaStream_t stream) {
  CheckInputTensor(input, "swiglu_input");
  CheckOutputTensor(*output, "swiglu_output");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
             "Input shape[0] must be equal to output shape[0].");
  NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
             "Input shape[1] must be 2x larger than output shape[1].");

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      constexpr int nvec = 32 / sizeof(IType);
      GatedActivationKernelLauncher<nvec, fp32, Empty, swish<fp32, fp32>>(
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        output->data.shape[0],
        output->data.shape[1],
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

void dswiglu(const Tensor &grad,
             const Tensor &input,
             Tensor *output,
             cudaStream_t stream) {
  CheckInputTensor(grad, "dswiglu_grad");
  CheckInputTensor(input, "dswiglu_input");
  CheckOutputTensor(*output, "dswiglu_output");
  NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
             "Output shape[0] must be equal to grad shape[0].");
  NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
             "Output shape[1] must be 2x larger than grad shape[1].");
  NVTE_CHECK(input.data.shape == output->data.shape,
             "Input and output shapes must match.");

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      constexpr int nvec = 32 / sizeof(IType);
      DGatedActivationKernelLauncher<nvec, fp32, Empty, swish<fp32, fp32>, dswish<fp32, fp32>>(
        reinterpret_cast<const IType*>(grad.data.dptr),
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        grad.data.shape[0],
        grad.data.shape[1],
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

}  // namespace transformer_engine

void nvte_swiglu(const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_swiglu);
  using namespace transformer_engine;
  swiglu(*reinterpret_cast<const Tensor*>(input),
         reinterpret_cast<Tensor*>(output),
         stream);
}

void nvte_dswiglu(const NVTETensor grad,
                  const NVTETensor input,
                  NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dswiglu);
  using namespace transformer_engine;
  dswiglu(*reinterpret_cast<const Tensor*>(grad),
          *reinterpret_cast<const Tensor*>(input),
          reinterpret_cast<Tensor*>(output),
          stream);
}
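Cross-checking the swish/dswish math used by swiglu against PyTorch autograd (a sketch, not from the PR; swish(x) = x * sigmoid(x) is what PyTorch calls SiLU):

import torch

x = torch.randn(16, dtype=torch.float64, requires_grad=True)
torch.nn.functional.silu(x).sum().backward()

s = torch.sigmoid(x.detach())
dswish = s + x.detach() * s * (1 - s)  # sigmoid + x * dsigmoid, as in util/math.h
assert torch.allclose(x.grad, dswish)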
@@ -27,11 +27,23 @@ void nvte_gelu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream);

+/*! \brief Compute GELU activation gradient.
+ *
+ *  \param[in]     grad    Incoming gradient.
+ *  \param[in]     input   Input tensor for GELU activation.
+ *  \param[in,out] output  Output tensor.
+ *  \param[in]     stream  CUDA stream used for the operation.
+ */
+void nvte_dgelu(const NVTETensor grad,
+                const NVTETensor input,
+                NVTETensor output,
+                cudaStream_t stream);
+
 /*! \brief Compute GeGLU of the input.
  *
  *  \param[in]     input   Input tensor of shape [N, H * 2].
- *                         It computes GELU([N, :H]) x [N, H:]
  *  \param[in,out] output  Output tensor of shape [N, H].
+ *                         It computes GELU(input[N, :H]) x input[N, H:]
  *  \param[in]     stream  CUDA stream used for the operation.
  */
 void nvte_geglu(const NVTETensor input,
@@ -39,9 +51,9 @@ void nvte_geglu(const NVTETensor input,
                 cudaStream_t stream);

 /*! \brief Compute GeGLU gradient.
- *  \param[in]     grad    Input tensor of shape [N, H].
- *  \param[in]     input   Input tensor of shape [N, H * 2].
- *  \param[in,out] output  Output tensor of shape [N, H * 2].
+ *  \param[in]     grad    Incoming gradient of shape [N, H].
+ *  \param[in]     input   Forward input tensor of shape [N, H * 2].
+ *  \param[in,out] output  Outgoing gradient of shape [N, H * 2].
  *  \param[in]     stream  CUDA stream used for the operation.
  */
 void nvte_dgeglu(const NVTETensor grad,
@@ -49,6 +61,72 @@ void nvte_dgeglu(const NVTETensor grad,
                  NVTETensor output,
                  cudaStream_t stream);

+/*! \brief Compute RELU activation of the input.
+ *
+ *  \param[in]     input   Input tensor for RELU activation.
+ *  \param[in,out] output  Output tensor.
+ *  \param[in]     stream  CUDA stream used for the operation.
+ */
+void nvte_relu(const NVTETensor input,
+               NVTETensor output,
+               cudaStream_t stream);
+
+/*! \brief Compute RELU activation gradient.
+ *
+ *  \param[in]     grad    Incoming gradient.
+ *  \param[in]     input   Input tensor for RELU activation.
+ *  \param[in,out] output  Output tensor.
+ *  \param[in]     stream  CUDA stream used for the operation.
+ */
+void nvte_drelu(const NVTETensor grad,
+                const NVTETensor input,
+                NVTETensor output,
+                cudaStream_t stream);
+
+/*! \brief Compute SwiGLU activation of the input.
+ *
+ *  \param[in]     input   Input tensor of shape [N, H * 2].
+ *  \param[in,out] output  Output tensor of shape [N, H].
+ *                         It computes Swish(input[N, :H]) x input[N, H:]
+ *  \param[in]     stream  CUDA stream used for the operation.
+ */
+void nvte_swiglu(const NVTETensor input,
+                 NVTETensor output,
+                 cudaStream_t stream);
+
+/*! \brief Compute SwiGLU gradient.
+ *  \param[in]     grad    Incoming gradient of shape [N, H].
+ *  \param[in]     input   Forward input tensor of shape [N, H * 2].
+ *  \param[in,out] output  Outgoing gradient of shape [N, H * 2].
+ *  \param[in]     stream  CUDA stream used for the operation.
+ */
+void nvte_dswiglu(const NVTETensor grad,
+                  const NVTETensor input,
+                  NVTETensor output,
+                  cudaStream_t stream);
+
+/*! \brief Compute ReGLU activation of the input.
+ *
+ *  \param[in]     input   Input tensor of shape [N, H * 2].
+ *  \param[in,out] output  Output tensor of shape [N, H].
+ *                         It computes ReLU(input[N, :H]) x input[N, H:]
+ *  \param[in]     stream  CUDA stream used for the operation.
+ */
+void nvte_reglu(const NVTETensor input,
+                NVTETensor output,
+                cudaStream_t stream);
+
+/*! \brief Compute ReGLU gradient.
+ *  \param[in]     grad    Incoming gradient of shape [N, H].
+ *  \param[in]     input   Forward input tensor of shape [N, H * 2].
+ *  \param[in,out] output  Outgoing gradient of shape [N, H * 2].
+ *  \param[in]     stream  CUDA stream used for the operation.
+ */
+void nvte_dreglu(const NVTETensor grad,
+                 const NVTETensor input,
+                 NVTETensor output,
+                 cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
...
@@ -712,7 +712,7 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
       for (unsigned int j = 0; j < nvec_out; ++j) {
 #pragma unroll
         for (unsigned int k = 0; k < nvec_in; ++k) {
-          after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k]) *
+          after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
                                        CType(in[current_in ^ 1][j].data.elt[k]);
         }
       }
@@ -895,7 +895,7 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
       for (unsigned int j = 0; j < nvec_out; ++j) {
 #pragma unroll
         for (unsigned int k = 0; k < nvec_in; ++k) {
-          after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k]) *
+          after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
                                        CType(in[current_in ^ 1][j].data.elt[k]);
         }
       }
@@ -1067,11 +1067,11 @@ dgeglu_cast_transpose_kernel(const IType * const input,
       for (unsigned int j = 0; j < nvec_out; ++j) {
 #pragma unroll
         for (unsigned int k = 0; k < nvec_in; ++k) {
-          after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k]) *
+          after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
                                        CType(in[current_in ^ 1][j].data.elt[k]) *
                                        CType(gate_in[current_in ^ 1][j].data.elt[k]);
           after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
-                                       gelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k]);
+                                       gelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k], {});
         }
       }
       OVec out_trans_0[nvec_in];  // NOLINT(*)
@@ -1264,11 +1264,11 @@ dgeglu_cast_transpose_kernel_notaligned(const IType * const input,
       for (unsigned int j = 0; j < nvec_out; ++j) {
 #pragma unroll
         for (unsigned int k = 0; k < nvec_in; ++k) {
-          after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k]) *
+          after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
                                        CType(in[current_in ^ 1][j].data.elt[k]) *
                                        CType(gate_in[current_in ^ 1][j].data.elt[k]);
           after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
-                                       gelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k]);
+                                       gelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k], {});
         }
       }
       OVec out_trans_0[nvec_in];  // NOLINT(*)
...
@@ -8,16 +8,17 @@
 #define TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_

 namespace transformer_engine {
-namespace {
+
+struct Empty {};

 template <typename OType, typename IType>
-__device__ inline OType gelu(const IType val) {
+__device__ inline OType gelu(const IType val, const Empty&) {
   const float cval = val;
   return cval * (0.5F + 0.5F * tanhf(cval * (0.79788456F + 0.03567741F * cval * cval)));
 }

 template <typename OType, typename IType>
-__device__ inline OType dgelu(const IType val) {
+__device__ inline OType dgelu(const IType val, const Empty&) {
   const float cval = val;
   const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
   return 0.5f * cval * ((1.f - tanh_out * tanh_out) *
@@ -25,7 +26,42 @@ __device__ inline OType dgelu(const IType val) {
          0.5f * (1.f + tanh_out);
 }

-}  // namespace
+template <typename OType, typename IType>
+__device__ inline OType sigmoid(const IType val, const Empty&) {
+  const float cval = val;
+  return 1.f / (1.f + expf(-cval));
+}
+
+template <typename OType, typename IType>
+__device__ inline OType dsigmoid(const IType val, const Empty& e) {
+  const float cval = val;
+  const float s = sigmoid<float, float>(cval, e);
+  return s * (1.f - s);
+}
+
+template <typename OType, typename IType>
+__device__ inline OType swish(const IType val, const Empty& e) {
+  const float cval = val;
+  return cval * sigmoid<float, float>(cval, e);
+}
+
+template <typename OType, typename IType>
+__device__ inline OType dswish(const IType val, const Empty& e) {
+  const float cval = val;
+  return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
+}
+
+template <typename OType, typename IType>
+__device__ inline OType relu(IType value, const Empty &) {
+  return fmaxf(value, 0.f);
+}
+
+template <typename OType, typename IType>
+__device__ inline OType drelu(IType value, const Empty &) {
+  return value > 0.f ? 1.f : 0.f;
+}

 }  // namespace transformer_engine

 #endif  // TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_
...
@@ -230,6 +230,66 @@ __global__ void unary_kernel(const InputType *input,
   }
 }

+template <int nvec, bool aligned,
+          typename ComputeType,
+          typename Param,
+          ComputeType (*OP)(ComputeType, const Param&),
+          typename InputType,
+          typename InputTypeGrad,
+          typename OutputType>
+__launch_bounds__(unary_kernel_threads)
+__global__ void unary_grad_kernel(const InputTypeGrad *grad,
+                                  const InputType *input,
+                                  OutputType *output,
+                                  const ComputeType *scale,
+                                  ComputeType *amax,
+                                  Param p,
+                                  const size_t N,
+                                  const size_t num_aligned_elements) {
+  VectorizedLoader<InputType, nvec, aligned> loader(input, N);
+  VectorizedLoader<InputTypeGrad, nvec, aligned> grad_loader(grad, N);
+  VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
+  ComputeType max = 0;
+  ComputeType s = 0;
+  if constexpr (is_fp8<OutputType>::value) {
+    if (scale != nullptr) s = *scale;
+  }
+  const int warp_id = threadIdx.x / THREADS_PER_WARP;
+
+  const size_t M = num_aligned_elements;
+
+  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+       tid < M;
+       tid += gridDim.x * blockDim.x) {
+    loader.load(tid, N);
+    grad_loader.load(tid, N);
+#pragma unroll
+    for (int i = 0; i < nvec; ++i) {
+      const ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
+      const ComputeType g = static_cast<ComputeType>(grad_loader.separate()[i]);
+      ComputeType temp = OP(val, p) * g;
+      if constexpr (is_fp8<OutputType>::value) {
+        __builtin_assume(max >= 0);
+        max = fmaxf(fabsf(temp), max);
+        temp = temp * s;
+      }
+      storer.separate()[i] = static_cast<OutputType>(temp);
+    }
+    storer.store(tid, N);
+  }
+  if constexpr (is_fp8<OutputType>::value) {
+    /* warp tile amax reduce*/
+    max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
+
+    if (threadIdx.x == 0 && amax != nullptr) {
+      static_assert(std::is_same<ComputeType, float>::value);
+      atomicMaxFloat(amax, max);
+    }
+  }
+}
+
 namespace {

 inline size_t get_num_aligned_elements(const void *ptr, const size_t lead_dim,
@@ -285,7 +345,7 @@ Alignment CheckAlignment(const size_t lead_dim,
 }  // namespace

 template <int nvec, typename Param,
-          fp32 (*OP)(fp32, const Param&),
+          fp32 (*OP)(const fp32, const Param&),
           typename InputType,
           typename OutputType>
 void VectorizedUnaryKernelLauncher(const InputType *input,
@@ -324,9 +384,52 @@ void VectorizedUnaryKernelLauncher(const InputType *input,
   }
 }

+template <int nvec, typename Param,
+          fp32 (*OP)(fp32, const Param&),
+          typename InputType,
+          typename InputTypeGrad,
+          typename OutputType>
+void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad,
+                                       const InputType *input,
+                                       OutputType *output,
+                                       const fp32 *scale,
+                                       fp32 *amax,
+                                       const size_t N,
+                                       const Param params,
+                                       cudaStream_t stream) {
+  if (N != 0) {
+    auto align = CheckAlignment(N, nvec, input, grad, output);
+
+    size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec,
+                                                           sizeof(InputType));
+    constexpr size_t threads = unary_kernel_threads;
+    size_t num_blocks = DIVUP(num_aligned_elements, threads);
+    constexpr size_t max_blocks = 65535;
+    num_blocks = std::min(num_blocks, max_blocks);
+
+    switch (align) {
+      case Alignment::SAME_ALIGNED:
+        unary_grad_kernel<nvec, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
+            grad, input, output, scale, amax, params, N, num_aligned_elements);
+        break;
+      case Alignment::SAME_UNALIGNED:
+        unary_grad_kernel<nvec, false, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
+            grad, input, output, scale, amax, params, N, num_aligned_elements);
+        break;
+      case Alignment::DIFFERENT: {
+        // If the pointers are aligned differently we cannot vectorize
+        unary_grad_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
+            grad, input, output, scale, amax, params, N, N);
+        break;
+      }
+    }
+  }
+}
+
 template <int nvec, bool aligned,
           typename ComputeType,
-          ComputeType (*Activation)(ComputeType),
+          typename Param,
+          ComputeType (*Activation)(const ComputeType, const Param&),
           typename InputType,
           typename OutputType>
 __launch_bounds__(unary_kernel_threads)
@@ -336,6 +439,7 @@ __global__ void gated_act_kernel(const InputType *input,
                                  ComputeType *amax,
                                  const size_t m,
                                  const size_t n,
+                                 const Param p,
                                  const size_t num_aligned_elements) {
   const size_t M = num_aligned_elements * m;
   for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -359,7 +463,7 @@ __global__ void gated_act_kernel(const InputType *input,
     for (int i = 0; i < nvec; ++i) {
       const ComputeType val = static_cast<ComputeType>(loader0.separate()[i]);
       const ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
-      ComputeType temp = static_cast<ComputeType>(Activation(val) * val2);
+      ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
       if constexpr (is_fp8<OutputType>::value) {
         __builtin_assume(max >= 0);
         max = fmaxf(fabsf(temp), max);
@@ -383,7 +487,8 @@ __global__ void gated_act_kernel(const InputType *input,

 template <int nvec,
           typename ComputeType,
-          ComputeType (*Activation)(ComputeType),
+          typename Param,
+          ComputeType (*Activation)(const ComputeType, const Param&),
           typename InputType,
           typename OutputType>
 void GatedActivationKernelLauncher(const InputType *input,
@@ -392,6 +497,7 @@ void GatedActivationKernelLauncher(const InputType *input,
                                    fp32 *amax,
                                    const size_t m,
                                    const size_t n,
+                                   const Param &p,
                                    cudaStream_t stream) {
   if (m != 0 && n != 0) {
     size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType));
@@ -402,17 +508,20 @@ void GatedActivationKernelLauncher(const InputType *input,
     switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) {
       case Alignment::SAME_ALIGNED:
-        gated_act_kernel<nvec, true, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
-            input, output, scale, amax, m, n, num_aligned_elements);
+        gated_act_kernel<nvec, true, ComputeType, Param, Activation>
+          <<<num_blocks, threads, 0, stream>>>(
+            input, output, scale, amax, m, n, p, num_aligned_elements);
         break;
       case Alignment::SAME_UNALIGNED:
-        gated_act_kernel<nvec, false, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
-            input, output, scale, amax, m, n, num_aligned_elements);
+        gated_act_kernel<nvec, false, ComputeType, Param, Activation>
+          <<<num_blocks, threads, 0, stream>>>(
+            input, output, scale, amax, m, n, p, num_aligned_elements);
         break;
       case Alignment::DIFFERENT: {
         // If the pointers are aligned differently we cannot vectorize
-        gated_act_kernel<1, true, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
-            input, output, scale, amax, m, n, n);
+        gated_act_kernel<1, true, ComputeType, Param, Activation>
+          <<<num_blocks, threads, 0, stream>>>(
+            input, output, scale, amax, m, n, p, n);
         break;
       }
     }
@@ -421,8 +530,9 @@ void GatedActivationKernelLauncher(const InputType *input,

 template <int nvec, bool aligned,
           typename ComputeType,
-          ComputeType (*Activation)(ComputeType),
-          ComputeType (*Dactivation)(ComputeType),
+          typename Param,
+          ComputeType (*Activation)(const ComputeType, const Param&),
+          ComputeType (*Dactivation)(const ComputeType, const Param&),
           typename InputType,
           typename OutputType>
 __launch_bounds__(unary_kernel_threads)
@@ -431,6 +541,7 @@ __global__ void dgated_act_kernel(const InputType *grad,
                                   OutputType *output,
                                   const size_t m,
                                   const size_t n,
+                                  const Param p,
                                   const size_t num_aligned_elements) {
   const size_t M = num_aligned_elements * m;
   for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -454,8 +565,8 @@ __global__ void dgated_act_kernel(const InputType *grad,
       const ComputeType gelu_in = static_cast<ComputeType>(input_loader0.separate()[i]);
       const ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);

-      ComputeType after_dgelu = Dactivation(gelu_in) * grad_val * gate_in;
-      ComputeType after_dgate = grad_val * Activation(gelu_in);
+      ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
+      ComputeType after_dgate = grad_val * Activation(gelu_in, p);
       storer0.separate()[i] = static_cast<OutputType>(after_dgelu);
       storer1.separate()[i] = static_cast<OutputType>(after_dgate);
@@ -467,8 +578,9 @@ __global__ void dgated_act_kernel(const InputType *grad,

 template <int nvec,
           typename ComputeType,
-          ComputeType (*Activation)(ComputeType),
-          ComputeType (*Dactivation)(ComputeType),
+          typename Param,
+          ComputeType (*Activation)(const ComputeType, const Param&),
+          ComputeType (*Dactivation)(const ComputeType, const Param&),
           typename InputType,
           typename OutputType>
 void DGatedActivationKernelLauncher(const InputType *grad,
@@ -476,6 +588,7 @@ void DGatedActivationKernelLauncher(const InputType *grad,
                                     OutputType *output,
                                     const size_t m,
                                     const size_t n,
+                                    const Param &p,
                                     cudaStream_t stream) {
   if (m != 0 && n != 0) {
     size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec,
@@ -487,17 +600,17 @@ void DGatedActivationKernelLauncher(const InputType *grad,
     switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) {
       case Alignment::SAME_ALIGNED:
-        dgated_act_kernel<nvec, true, ComputeType, Activation, Dactivation>
-          <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, num_aligned_elements);
+        dgated_act_kernel<nvec, true, ComputeType, Param, Activation, Dactivation>
+          <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p, num_aligned_elements);
         break;
       case Alignment::SAME_UNALIGNED:
-        dgated_act_kernel<nvec, false, ComputeType, Activation, Dactivation>
-          <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, num_aligned_elements);
+        dgated_act_kernel<nvec, false, ComputeType, Param, Activation, Dactivation>
+          <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p, num_aligned_elements);
         break;
       case Alignment::DIFFERENT: {
         // If the pointers are aligned differently we cannot vectorize
-        dgated_act_kernel<1, true, ComputeType, Activation, Dactivation>
-          <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, n);
+        dgated_act_kernel<1, true, ComputeType, Param, Activation, Dactivation>
+          <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p, n);
         break;
       }
     }
...
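A Python reference of what the gated kernels above compute (a sketch, not from the PR): gated_act_kernel produces act(a) * b from x = [a | b] along the last dimension, and dgated_act_kernel produces the two gradient halves side by side:

import torch

def gated_act_ref(x, act):
    # forward: x is [m, 2n]; the first half is activated and gates the second
    a, b = x.chunk(2, dim=-1)
    return act(a) * b

def dgated_act_ref(g, x, act, dact):
    # backward: returns the [m, 2n] gradient with respect to x
    a, b = x.chunk(2, dim=-1)
    return torch.cat([dact(a) * g * b, g * act(a)], dim=-1)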
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for c++ extensions"""
from transformer_engine_extensions import *
from .fused_attn import *
from .gemm import *
from .transpose import *
from .activation import *
from .normalization import *
from .cast import *
...

# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for activation extensions"""
from typing import Union

import torch

import transformer_engine_extensions as tex


__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu']


def gelu(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """GeLU with FP8 output"""
    empty_tensor = torch.Tensor()
    if fp8_meta_tensor is not None:
        scale = fp8_meta_tensor.scale
        amax_history = fp8_meta_tensor.amax_history
        scale_inv = fp8_meta_tensor.scale_inv
    else:
        scale = empty_tensor
        amax_history = empty_tensor
        scale_inv = empty_tensor
    return torch.ops.tex_ts.gelu_ts(
        inp,
        scale,
        amax_history,
        scale_inv,
        fp8_tensor,
        otype,
    )


def relu(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """ReLU with FP8 output"""
    empty_tensor = torch.Tensor()
    if fp8_meta_tensor is not None:
        scale = fp8_meta_tensor.scale
        amax_history = fp8_meta_tensor.amax_history
        scale_inv = fp8_meta_tensor.scale_inv
    else:
        scale = empty_tensor
        amax_history = empty_tensor
        scale_inv = empty_tensor
    return torch.ops.tex_ts.relu_ts(
        inp,
        scale,
        amax_history,
        scale_inv,
        fp8_tensor,
        otype,
    )


def geglu(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """GeGLU with FP8 output"""
    empty_tensor = torch.Tensor()
    if fp8_meta_tensor is not None:
        scale = fp8_meta_tensor.scale
        amax_history = fp8_meta_tensor.amax_history
        scale_inv = fp8_meta_tensor.scale_inv
    else:
        scale = empty_tensor
        amax_history = empty_tensor
        scale_inv = empty_tensor
    return torch.ops.tex_ts.geglu_ts(
        inp,
        scale,
        amax_history,
        scale_inv,
        fp8_tensor,
        otype,
    )


def reglu(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """ReGLU with FP8 output"""
    empty_tensor = torch.Tensor()
    if fp8_meta_tensor is not None:
        scale = fp8_meta_tensor.scale
        amax_history = fp8_meta_tensor.amax_history
        scale_inv = fp8_meta_tensor.scale_inv
    else:
        scale = empty_tensor
        amax_history = empty_tensor
        scale_inv = empty_tensor
    return torch.ops.tex_ts.reglu_ts(
        inp,
        scale,
        amax_history,
        scale_inv,
        fp8_tensor,
        otype,
    )


def swiglu(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """SwiGLU with FP8 output"""
    empty_tensor = torch.Tensor()
    if fp8_meta_tensor is not None:
        scale = fp8_meta_tensor.scale
        amax_history = fp8_meta_tensor.amax_history
        scale_inv = fp8_meta_tensor.scale_inv
    else:
        scale = empty_tensor
        amax_history = empty_tensor
        scale_inv = empty_tensor
    return torch.ops.tex_ts.swiglu_ts(
        inp,
        scale,
        amax_history,
        scale_inv,
        fp8_tensor,
        otype,
    )
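A hypothetical invocation of the new wrappers in the non-FP8 path (a sketch; passing fp8_meta_tensor=None makes the wrapper forward empty scale/amax tensors, and the enum index used here is an arbitrary placeholder):

import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import swiglu

x = torch.randn(128, 2 * 4096, dtype=torch.float16, device="cuda")
y = swiglu(x, None, tex.FP8FwdTensors.GEMM1_OUTPUT, tex.DType.kFloat16)
assert y.shape == (128, 4096)  # gated activations halve the last dimension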
...

# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for cast extensions"""
from typing import Optional, Union

import torch

import transformer_engine_extensions as tex


__all__ = ['cast_to_fp8',
           'cast_from_fp8']


def cast_to_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    out: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    """Cast input to FP8"""
    if out is not None:
        tex.cast_to_fp8_noalloc(
            inp,
            fp8_meta_tensor.scale[fp8_tensor],
            out,
            fp8_meta_tensor.amax_history[0][fp8_tensor],
            fp8_meta_tensor.scale_inv[fp8_tensor],
            otype
        )
        return None
    return torch.ops.tex_ts.cast_to_fp8_ts(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        otype,
    )


def cast_from_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    itype: tex.DType,
    otype: tex.DType,
) -> torch.Tensor:
    """Cast input from FP8"""
    return torch.ops.tex_ts.cast_from_fp8_ts(
        inp,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        itype,
        otype,
    )
...@@ -2,12 +2,18 @@
#
# See LICENSE for license information.
"""Python interface for fused attention extensions"""
import math
from typing import Tuple, List, Union
import torch
import transformer_engine_extensions as tex
from .constants import TE_DType


__all__ = ['fused_attn_fwd_qkvpacked',
           'fused_attn_bwd_qkvpacked',
           'fused_attn_fwd_kvpacked',
           'fused_attn_bwd_kvpacked']


TORCH_DType = {
    tex.DType.kFloat8E4M3: torch.uint8,
...@@ -18,11 +24,13 @@ TORCH_DType = {
    tex.DType.kInt32: torch.int32,
}


def check_tensor(x: torch.Tensor):
    """Check tensor properties."""
    assert (x.is_cuda and x.is_contiguous()
            ), "Tensor should be a GPU tensor and contiguous."


def check_qkv(qkv: torch.Tensor, dtype: torch.dtype):
    """Check tensor properties."""
    check_tensor(qkv)
...@@ -32,6 +40,7 @@ def check_qkv(qkv: torch.Tensor, dtype: torch.dtype):
            ), """QKV should be in [total_seqs, 3, num_heads, head_dim] shape
            and {dtype} dtype."""


def check_q(q: torch.Tensor, dtype: torch.dtype):
    """Check tensor properties."""
    check_tensor(q)
...@@ -40,6 +49,7 @@ def check_q(q: torch.Tensor, dtype: torch.dtype):
            ), """Q should be in [total_seqs, num_heads, head_dim] shape
            and {dtype} dtype."""


def check_kv(kv: torch.Tensor, dtype: torch.dtype):
    """Check tensor properties."""
    check_tensor(kv)
...@@ -49,6 +59,7 @@ def check_kv(kv: torch.Tensor, dtype: torch.dtype):
            ), """KV should be in [total_seqs, 2, num_heads, head_dim] shape
            and {dtype} dtype."""


def check_o(o: torch.Tensor, dtype: torch.dtype):
    """Check tensor properties."""
    check_tensor(o)
...@@ -57,6 +68,7 @@ def check_o(o: torch.Tensor, dtype: torch.dtype):
            ), """O and dO should be in [total_seqs, num_heads, head_dim] shape
            and {dtype} dtype."""


def check_stats(stats: torch.Tensor, b: int, h: int, s: int):
    """Check tensor properties."""
    check_tensor(stats)
...@@ -66,6 +78,7 @@ def check_stats(stats: torch.Tensor, b: int, h: int, s: int):
            ), """M and ZInv should be in [batch_size, num_heads, max_seqlen_q, 1]
            shape and float32 dtype."""


def check_cu_seqlens(cu_seqlens: torch.Tensor):
    """Check tensor properties."""
    check_tensor(cu_seqlens)
...@@ -81,6 +94,7 @@ def check_scalar(scalar: torch.Tensor):
            and scalar.numel() == 1
            ), "amax/scale/descale tensors should be scalars in float32 dtype."


def check_rng_state(rng_state: torch.Tensor):
    """Check tensor properties."""
    check_tensor(rng_state)
...@@ -88,6 +102,7 @@ def check_rng_state(rng_state: torch.Tensor):
            and rng_state.numel() == 2
            ), "rng_state should be [seed, offset] and in int64 dtype."


def fused_attn_fwd_qkvpacked(
    is_training: bool,
    max_seqlen: int,
...@@ -749,450 +764,3 @@ def fused_attn_bwd_kvpacked(
    if bias_type == "no_bias":
        return output_tensors[:2]
    return output_tensors
def fp8_gemm(
A: torch.Tensor,
A_scale_inv: torch.Tensor,
A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
A_dtype: tex.DType,
B: torch.Tensor,
B_scale_inv: torch.Tensor,
B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
B_dtype: tex.DType,
out_dtype: torch.dtype,
workspace: torch.Tensor,
gelu: bool = False,
accumulate: bool = False,
out: Optional[torch.Tensor] = None,
out_index = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
bias: Optional[torch.Tensor] = None,
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
ub_algo: tex.UbufOverlapAlgo = None,
ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None,
extra_output_tensor: torch.Tensor = None,
) -> torch.Tensor:
"""TN layout GEMM with fp8 inputs."""
empty_tensor = torch.Tensor()
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_index is not None
return_output = False
if out is None:
out = torch.empty(
B.shape[0],
A.shape[0],
dtype=out_dtype,
device="cuda",
)
return_output = True
# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias.dtype
if gelu:
gelu_input = torch.empty_like(out, dtype=bias_dtype)
else:
gelu_input = empty_tensor
bias_dtype = TE_DType[bias_dtype]
out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
args = (
A,
A_scale_inv,
A_fp8_tensor,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor,
B_dtype,
False, # transb
out,
empty_tensor if out_index is None else fp8_meta_tensor.scale[out_index],
out_dtype,
empty_tensor if out_index is None else fp8_meta_tensor.amax_history[0][out_index],
bias if use_bias else empty_tensor,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator)
fn = torch.ops.tex_ts.te_gemm_ts
if ub_algo is not None:
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
args = tuple(args + (1,))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
args = tuple(args + (0,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS requires extra output tensor'
args = tuple(args + (True, extra_output_tensor,))
_ = fn(*args)
if return_output:
if gelu:
return out, gelu_input
return out
if gelu:
return gelu_input
return None
def gemm(
A: torch.Tensor,
B: torch.Tensor,
dtype: torch.dtype,
workspace: torch.Tensor,
gelu: bool = False,
gelu_input: Optional[torch.Tensor] = None,
grad: bool = False,
accumulate: bool = False,
layout: str = "TN",
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_bias: bool = False,
ub_algo: tex.UbufOverlapAlgo = None,
ub: tex.UbufCommOverlap = None,
extra_output_tensor: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Non FP8 GEMM."""
assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
transa = layout[0] == "T"
transb = layout[1] == "T"
empty_tensor = torch.Tensor()
fp8_index = -1 # dummy index
return_output = False
if out is None:
out = torch.empty(
B.shape[1] if transb else B.shape[0],
A.shape[0] if transa else A.shape[1],
dtype=dtype,
device="cuda",
)
return_output = True
if gelu and not grad:
gelu_input = torch.empty_like(out, dtype=dtype)
elif not gelu:
gelu_input = empty_tensor
if grad and use_bias:
grad_bias = torch.empty(B.shape[1], dtype=out.dtype, device="cuda")
else:
grad_bias = empty_tensor
bias = bias if use_bias else empty_tensor
assert A.dtype == dtype and B.dtype == dtype, \
f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
input_dtype = TE_DType[dtype]
output_dtype = TE_DType[out.dtype]
if use_bias:
bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype]
else:
bias_dtype = output_dtype
args = (
A,
empty_tensor,
fp8_index,
input_dtype,
transa,
B,
empty_tensor,
fp8_index,
input_dtype,
transb,
out,
empty_tensor, # out_scale
output_dtype,
empty_tensor, # out_amax
grad_bias if grad else bias,
bias_dtype,
gelu_input,
grad,
workspace,
workspace.shape[0],
accumulate,
False, # use_split_accumulator
)
fn = torch.ops.tex_ts.te_gemm_ts
if ub_algo is not None:
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
args = tuple(args + (1,))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
args = tuple(args + (0,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS requires extra output tensor'
args = tuple(args + (False, extra_output_tensor,))
_ = fn(*args)
if return_output:
return out, grad_bias, gelu_input
return None, grad_bias, gelu_input
def fp8_cast_transpose_fused(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
cast_out: Optional[torch.Tensor] = None,
transpose_out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], None]:
"""Cast + Transpose with FP8 output"""
return_outputs = False
if cast_out is None or transpose_out is None:
cast_out = torch.empty_like(inp, dtype=torch.uint8)
transpose_out = torch.empty(
inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8
)
return_outputs = True
tex.fused_cast_transpose(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
cast_out,
transpose_out,
otype,
)
if return_outputs:
return cast_out, transpose_out
return None
def fp8_cast_transpose_bgrad_fused(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Cast + Transpose + BGRAD with FP8 output"""
return tex.fused_cast_transpose_bgrad(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
)
def fp8_transpose_bgrad_fused(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
grad_bias_type: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transpose + BGRAD with FP8 output"""
return tex.fused_fp8_transpose_bgrad(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
TE_DType[grad_bias_type],
)
def fp8_cast_transpose_bgrad_dgelu_fused(
grad_output: torch.Tensor,
gelu_input: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Cast + Transpose + BGRAD + DGELU with FP8 output"""
return tex.fused_cast_transpose_bgrad_dgelu(
grad_output,
gelu_input,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
)
def fp8_gelu(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> torch.Tensor:
"""GeLU with FP8 output"""
return torch.ops.tex_ts.fp8_gelu_ts(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
)
def layernorm_fwd_fp8(
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
sm_margin: int,
zero_centered_gamma: bool,
ln_out: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""LayerNorm with FP8 output"""
if ln_out is not None:
return tex.layernorm_fwd_fp8_noalloc(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale[fp8_tensor],
ln_out,
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
sm_margin,
zero_centered_gamma
)
return tex.layernorm_fwd_fp8(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
sm_margin,
zero_centered_gamma
)
def layernorm_fwd_fp8_inf(
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
zero_centered_gamma,
) -> torch.Tensor:
"""LayerNorm with FP8 output.
This version of layernorm_fwd_fp8 is specialized for inference, and returns
only the normalized output.
"""
ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
zero_centered_gamma)
return ret
def layernorm_fwd_inf(
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
zero_centered_gamma: bool,
) -> torch.Tensor:
"""LayerNorm with FP8 output"""
return torch.ops.tex_ts.layernorm_fwd_inf_ts(
inp,
weight,
bias,
eps,
zero_centered_gamma,
)
def cast_to_fp8(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
out: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
"""Cast input to FP8"""
if out is not None:
tex.cast_to_fp8_noalloc(
inp,
fp8_meta_tensor.scale[fp8_tensor],
out,
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype
)
return None
return torch.ops.tex_ts.cast_to_fp8_ts(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
)
def cast_from_fp8(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
itype: tex.DType,
otype: tex.DType,
) -> torch.Tensor:
"""Cast input from FP8"""
return torch.ops.tex_ts.cast_from_fp8_ts(
inp,
fp8_meta_tensor.scale_inv,
fp8_tensor,
itype,
otype,
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for GEMM extensions"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_extensions as tex
from ..constants import TE_DType
__all__ = ['gemm', 'fp8_gemm']
def fp8_gemm(
A: torch.Tensor,
A_scale_inv: torch.Tensor,
A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
A_dtype: tex.DType,
B: torch.Tensor,
B_scale_inv: torch.Tensor,
B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
B_dtype: tex.DType,
out_dtype: torch.dtype,
workspace: torch.Tensor,
gelu: bool = False,
accumulate: bool = False,
out: Optional[torch.Tensor] = None,
out_index = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
bias: Optional[torch.Tensor] = None,
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
ub_algo: tex.UbufOverlapAlgo = None,
ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None,
extra_output_tensor: torch.Tensor = None,
) -> torch.Tensor:
"""TN layout GEMM with fp8 inputs."""
empty_tensor = torch.Tensor()
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_index is not None
return_output = False
if out is None:
out = torch.empty(
B.shape[0],
A.shape[0],
dtype=out_dtype,
device="cuda",
)
return_output = True
# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias.dtype
if gelu:
gelu_input = torch.empty_like(out, dtype=bias_dtype)
else:
gelu_input = empty_tensor
bias_dtype = TE_DType[bias_dtype]
out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
args = (
A,
A_scale_inv,
A_fp8_tensor,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor,
B_dtype,
False, # transb
out,
empty_tensor if out_index is None else fp8_meta_tensor.scale[out_index],
out_dtype,
empty_tensor if out_index is None else fp8_meta_tensor.amax_history[0][out_index],
bias if use_bias else empty_tensor,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator)
fn = torch.ops.tex_ts.te_gemm_ts
if ub_algo is not None:
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
args = tuple(args + (1,))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
args = tuple(args + (0,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS requires extra output tensor'
args = tuple(args + (True, extra_output_tensor,))
_ = fn(*args)
if return_output:
if gelu:
return out, gelu_input
return out
if gelu:
return gelu_input
return None
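
As a rough sketch of the calling convention (shapes, FP8 indices, and the workspace size are assumptions): both operands arrive pre-cast to FP8 and row-major, and the fixed TN layout means `A` is [n, k], `B` is [m, k], and the output is [m, n]:

scale_inv = torch.ones(16, dtype=torch.float32, device="cuda")
workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")
A_fp8 = torch.randint(0, 255, (256, 128), dtype=torch.uint8, device="cuda")  # weight [n, k]
B_fp8 = torch.randint(0, 255, (64, 128), dtype=torch.uint8, device="cuda")   # input  [m, k]
out = fp8_gemm(
    A_fp8, scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, tex.DType.kFloat8E4M3,
    B_fp8, scale_inv, tex.FP8FwdTensors.GEMM1_INPUT, tex.DType.kFloat8E4M3,
    torch.bfloat16, workspace,
)  # out: [64, 256] in bfloat16
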
def gemm(
A: torch.Tensor,
B: torch.Tensor,
dtype: torch.dtype,
workspace: torch.Tensor,
gelu: bool = False,
gelu_input: Optional[torch.Tensor] = None,
grad: bool = False,
accumulate: bool = False,
layout: str = "TN",
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_bias: bool = False,
ub_algo: tex.UbufOverlapAlgo = None,
ub: tex.UbufCommOverlap = None,
extra_output_tensor: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Non FP8 GEMM."""
assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
transa = layout[0] == "T"
transb = layout[1] == "T"
empty_tensor = torch.Tensor()
fp8_index = -1 # dummy index
return_output = False
if out is None:
out = torch.empty(
B.shape[1] if transb else B.shape[0],
A.shape[0] if transa else A.shape[1],
dtype=dtype,
device="cuda",
)
return_output = True
if gelu and not grad:
gelu_input = torch.empty_like(out, dtype=dtype)
elif not gelu:
gelu_input = empty_tensor
if grad and use_bias:
grad_bias = torch.empty(B.shape[1], dtype=out.dtype, device="cuda")
else:
grad_bias = empty_tensor
bias = bias if use_bias else empty_tensor
assert A.dtype == dtype and B.dtype == dtype, \
f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
input_dtype = TE_DType[dtype]
output_dtype = TE_DType[out.dtype]
if use_bias:
bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype]
else:
bias_dtype = output_dtype
args = (
A,
empty_tensor,
fp8_index,
input_dtype,
transa,
B,
empty_tensor,
fp8_index,
input_dtype,
transb,
out,
empty_tensor, # out_scale
output_dtype,
empty_tensor, # out_amax
grad_bias if grad else bias,
bias_dtype,
gelu_input,
grad,
workspace,
workspace.shape[0],
accumulate,
False, # use_split_accumulator
)
fn = torch.ops.tex_ts.te_gemm_ts
if ub_algo is not None:
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
args = tuple(args + (1,))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
args = tuple(args + (0,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS requires extra output tensor'
args = tuple(args + (False, extra_output_tensor,))
_ = fn(*args)
if return_output:
return out, grad_bias, gelu_input
return None, grad_bias, gelu_input
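
The non-FP8 path looks similar but takes operands directly in the working dtype (shapes illustrative):

A = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")  # [n, k]
B = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")   # [m, k]
workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")
out, _, _ = gemm(A, B, torch.bfloat16, workspace)  # default "TN" layout: out is [64, 256]
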
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for normalization extensions"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_extensions as tex
__all__ = ['layernorm_fwd_fp8',
'layernorm_fwd_fp8_inf',
'layernorm_fwd_inf']
def layernorm_fwd_fp8(
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
sm_margin: int,
zero_centered_gamma: bool,
ln_out: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""LayerNorm with FP8 output"""
if ln_out is not None:
return tex.layernorm_fwd_fp8_noalloc(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale[fp8_tensor],
ln_out,
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
sm_margin,
zero_centered_gamma
)
return tex.layernorm_fwd_fp8(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
sm_margin,
zero_centered_gamma
)
def layernorm_fwd_fp8_inf(
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
zero_centered_gamma,
) -> torch.Tensor:
"""LayerNorm with FP8 output.
This version of layernorm_fwd_fp8 is specialized for inference, and returns
only the normalized output.
"""
ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
zero_centered_gamma)
return ret
def layernorm_fwd_inf(
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
zero_centered_gamma: bool,
) -> torch.Tensor:
"""LayerNorm with FP8 output"""
return torch.ops.tex_ts.layernorm_fwd_inf_ts(
inp,
weight,
bias,
eps,
zero_centered_gamma,
)
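
A short sketch of the plain inference entry point (hidden size illustrative; that the input arrives flattened to 2D is an assumption about the kernel here):

hidden = 1024
x = torch.randn(256, hidden, device="cuda", dtype=torch.bfloat16)
w = torch.ones(hidden, device="cuda", dtype=torch.bfloat16)
b = torch.zeros(hidden, device="cuda", dtype=torch.bfloat16)
y = layernorm_fwd_inf(x, w, b, 1e-5, False)  # returns only the normalized tensor
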
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for transpose extensions"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_extensions as tex
from ..constants import TE_DType
__all__ = ['fp8_cast_transpose_fused',
'fp8_cast_transpose_bgrad_fused',
'fp8_cast_transpose_bgrad_dgelu_fused',
'fp8_transpose_bgrad_fused']
def fp8_cast_transpose_fused(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
cast_out: Optional[torch.Tensor] = None,
transpose_out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], None]:
"""Cast + Transpose with FP8 output"""
return_outputs = False
if cast_out is None or transpose_out is None:
cast_out = torch.empty_like(inp, dtype=torch.uint8)
transpose_out = torch.empty(
inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8
)
return_outputs = True
tex.fused_cast_transpose(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
cast_out,
transpose_out,
otype,
)
if return_outputs:
return cast_out, transpose_out
return None
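
For instance (meta setup as in the earlier cast sketch; sizes illustrative), the fused kernel produces both layouts in one pass:

meta = tex.FP8TensorMeta()
meta.scale = torch.ones(16, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(16, dtype=torch.float32, device="cuda")
meta.amax_history = torch.zeros(1, 16, dtype=torch.float32, device="cuda")
x = torch.randn(128, 64, device="cuda")
c, t = fp8_cast_transpose_fused(x, meta, tex.FP8FwdTensors.GEMM1_INPUT,
                                tex.DType.kFloat8E4M3)
# c keeps x's layout ([128, 64]); t is the FP8 transpose ([64, 128]),
# typically consumed by the backward (wgrad) GEMM.
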
def fp8_cast_transpose_bgrad_fused(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Cast + Transpose + BGRAD with FP8 output"""
return tex.fused_cast_transpose_bgrad(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
)
def fp8_transpose_bgrad_fused(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
grad_bias_type: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transpose + BGRAD with FP8 output"""
return tex.fused_fp8_transpose_bgrad(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
TE_DType[grad_bias_type],
)
def fp8_cast_transpose_bgrad_dgelu_fused(
grad_output: torch.Tensor,
gelu_input: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Cast + Transpose + BGRAD + DGELU with FP8 output"""
return tex.fused_cast_transpose_bgrad_dgelu(
grad_output,
gelu_input,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
)
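
A shape-level sketch of the fused backward helper above, with meta set up as in the cast-transpose sketch; the `GRAD_OUTPUT1` slot and the return order are assumptions based on how the other fused kernels here are used:

grad = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)      # incoming dY
pre_gelu = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)  # fc1 output saved in fwd
bgrad, dgelu_cast, dgelu_t = fp8_cast_transpose_bgrad_dgelu_fused(
    grad, pre_gelu, meta, tex.FP8BwdTensors.GRAD_OUTPUT1, tex.DType.kFloat8E5M2)
# bgrad: bias gradient in grad's dtype; dgelu_cast/dgelu_t: FP8 dGeLU and its transpose.
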
...@@ -172,5 +172,4 @@ at::Tensor allocateTorchTensor(int M,
                               transformer_engine::DType dtype
);
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
...@@ -1131,7 +1131,7 @@ at::Tensor fp8_transpose(at::Tensor input,
}

at::Tensor gelu(at::Tensor input,
                at::Tensor scale,
                at::Tensor amax,
                at::Tensor scale_inv,
...@@ -1139,15 +1139,16 @@ at::Tensor fp8_gelu(at::Tensor input,
) {
  using namespace transformer_engine;
  size_t N = static_cast<size_t>(input.size(-1));
  size_t M = input.numel() / N;

  auto output =
      allocateTorchTensor(M,
                          N,
                          otype);

  auto itype = GetTransformerEngineDType(input.scalar_type());
  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype,
                                               amax.data_ptr(), scale.data_ptr(),
                                               scale_inv.data_ptr());
...@@ -1157,6 +1158,238 @@ at::Tensor fp8_gelu(at::Tensor input,
  return output;
}
at::Tensor dgelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_dgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor relu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = static_cast<size_t>(input.numel()) / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_relu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor drelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_drelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor geglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N / 2,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_geglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor dgeglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_dgeglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor reglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N / 2,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_reglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor dreglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_dreglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor swiglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N / 2,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_swiglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor dswiglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
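
Note the shape convention shared by the GLU kernels above: forward maps [M, N] to [M, N/2], while backward takes the [M, N/2] incoming gradient plus the original [M, N] input and returns an [M, N] input gradient. A quick Python-side sanity check against the bindings registered below (illustrative):

x = torch.randn(32, 256, device="cuda")     # forward input   [M, N]
g = torch.randn(32, 128, device="cuda")     # incoming grad   [M, N/2]
dx = tex.dswiglu(g, x, tex.DType.kFloat32)  # input gradient  [M, N]
assert dx.shape == x.shape
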
std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
                                      const at::Tensor &x,
...@@ -1982,7 +2215,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked,
        "Fused Attention FP8/BF16/FP16 BWD with packed KV");
  m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
  m.def("gelu", &gelu, "GeLU with FP8 output");
  m.def("relu", &relu, "ReLU with FP8 output");
  m.def("geglu", &geglu, "GeGLU with FP8 output");
  m.def("reglu", &reglu, "ReGLU with FP8 output");
  m.def("swiglu", &swiglu, "SwiGLU with FP8 output");
  m.def("dgelu", &dgelu, "Backward of GeLU");
  m.def("drelu", &drelu, "Backward of ReLU");
  m.def("dgeglu", &dgeglu, "Backward of GeGLU");
  m.def("dreglu", &dreglu, "Backward of ReGLU");
  m.def("dswiglu", &dswiglu, "Backward of SwiGLU");
  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");