Unverified commit 0e116d51, authored by Alp Dener, committed by GitHub
Browse files

QuickGELU activation from HuggingFace/Transformers (#475)



* Added QuickGELUActivation from HuggingFace/Transformers to common and pytorch
Signed-off-by: Alp Dener <adener@nvidia.com>

* Removing 'qgelu' from double-size activations list in LayerNormMLP.
Signed-off-by: Alp Dener <adener@nvidia.com>

* indent fix
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
parent d5c088da
......@@ -61,7 +61,7 @@ batch_sizes = [1, 2]
all_boolean = [True, False]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
......@@ -304,12 +304,16 @@ class TorchMHA(nn.Module):
output = output[0]
return output
class TorchQuickGELU(nn.Module):
    """Pure-PyTorch reference for QuickGELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply QuickGELU element-wise; the result has the same shape as ``input``."""
        gate = torch.sigmoid(1.702 * input)
        return input * gate
_supported_act = {'geglu' : nn.GELU(approximate="tanh"),
'gelu' : nn.GELU(approximate="tanh"),
'reglu' : nn.ReLU(),
'relu' : nn.ReLU(),
'swiglu' : nn.SiLU()}
'swiglu' : nn.SiLU(),
'qgelu' : TorchQuickGELU()}
class TorchGLU(nn.Module):
......
......@@ -127,6 +127,57 @@ void dgeglu(const Tensor &grad,
); // NOLINT(*)
}
/* Host-side launcher for the element-wise QuickGELU forward pass.
 *
 * input  : tensor to activate.
 * output : destination tensor; its shape must equal the input shape. Its
 *          scale and amax pointers are forwarded to the kernel launcher.
 * stream : CUDA stream the kernel is enqueued on.
 */
void qgelu(const Tensor &input,
           Tensor *output,
           cudaStream_t stream) {
  CheckInputTensor(input, "qgelu_input");
  CheckOutputTensor(*output, "qgelu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");

  const size_t num_elements = product(input.data.shape);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      /* Number of IType elements that fit in 32 bytes. */
      constexpr int nvec = 32 / sizeof(IType);
      const IType *src = reinterpret_cast<const IType*>(input.data.dptr);
      OType *dst = reinterpret_cast<OType*>(output->data.dptr);
      const fp32 *scale_ptr = reinterpret_cast<const fp32*>(output->scale.dptr);
      fp32 *amax_ptr = reinterpret_cast<fp32*>(output->amax.dptr);
      VectorizedUnaryKernelLauncher<nvec, Empty, qgelu<fp32, fp32> >(
        src,
        dst,
        scale_ptr,
        amax_ptr,
        num_elements,
        Empty(),
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}
/* Host-side launcher for the element-wise QuickGELU backward pass.
 *
 * grad   : incoming gradient; must have the same dtype and the same number of
 *          elements as `input`.
 * input  : tensor that was fed to the forward QuickGELU.
 * output : destination tensor; its shape must equal the input shape. Its
 *          scale and amax pointers are forwarded to the kernel launcher.
 * stream : CUDA stream the kernel is enqueued on.
 */
void dqgelu(const Tensor &grad,
            const Tensor &input,
            Tensor *output,
            cudaStream_t stream) {
  CheckInputTensor(input, "dqgelu_input");
  CheckInputTensor(grad, "dqgelu_input_grad");
  CheckOutputTensor(*output, "dqgelu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  NVTE_CHECK(input.data.dtype == grad.data.dtype,
             "Input and incoming gradient types must match.");
  const size_t tot_elts = product(input.data.shape);
  /* The kernel reads tot_elts elements from grad, so a smaller gradient buffer
   * would be read out of bounds. Only the element count is compared (not the
   * exact shape) so callers that pass a reshaped view of the gradient keep working. */
  NVTE_CHECK(product(grad.data.shape) == tot_elts,
             "Input and incoming gradient must have the same number of elements.");
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      /* Number of IType elements that fit in 32 bytes. */
      constexpr int nvec = 32 / sizeof(IType);
      VectorizedUnaryGradKernelLauncher<nvec, Empty, dqgelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(grad.data.dptr),
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        tot_elts,
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}
} // namespace transformer_engine
void nvte_gelu(const NVTETensor input,
......@@ -172,3 +223,25 @@ void nvte_dgeglu(const NVTETensor grad,
reinterpret_cast<Tensor*>(output),
stream);
}
/* C API entry point: unwraps the opaque tensor handles and runs the
 * QuickGELU forward launcher on the given stream. */
void nvte_qgelu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgelu);
  using namespace transformer_engine;
  const Tensor &input_tensor = *reinterpret_cast<const Tensor*>(input);
  Tensor *output_tensor = reinterpret_cast<Tensor*>(output);
  qgelu(input_tensor, output_tensor, stream);
}
/* C API entry point: unwraps the opaque tensor handles and runs the
 * QuickGELU backward launcher on the given stream. */
void nvte_dqgelu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgelu);
  using namespace transformer_engine;
  const Tensor &grad_tensor = *reinterpret_cast<const Tensor*>(grad);
  const Tensor &input_tensor = *reinterpret_cast<const Tensor*>(input);
  Tensor *output_tensor = reinterpret_cast<Tensor*>(output);
  dqgelu(grad_tensor, input_tensor, output_tensor, stream);
}
......@@ -47,8 +47,8 @@ void nvte_dgelu(const NVTETensor grad,
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_geglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute GeGLU gradient.
* \param[in] grad Incoming gradient of shape [N, H].
......@@ -113,8 +113,8 @@ void nvte_dswiglu(const NVTETensor grad,
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_reglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute ReGLU gradient.
* \param[in] grad Incoming gradient of shape [N, H].
......@@ -123,9 +123,31 @@ void nvte_reglu(const NVTETensor input,
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dreglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute QuickGELU activation of the input.
 *
 *  QuickGELU approximates GELU element-wise as input * sigmoid(1.702 * input).
 *
 *  \param[in]     input     Input tensor for QuickGELU activation.
 *  \param[in,out] output    Output tensor; its shape must match the shape of the input.
 *  \param[in]     stream    CUDA stream used for the operation.
 */
void nvte_qgelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute QuickGELU activation gradient.
 *
 *  \param[in]     grad      Incoming gradient; its data type must match the data type of the input.
 *  \param[in]     input     Input tensor for QuickGELU activation (the forward-pass input).
 *  \param[in,out] output    Output tensor; its shape must match the shape of the input.
 *  \param[in]     stream    CUDA stream used for the operation.
 */
void nvte_dqgelu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
......
......@@ -39,6 +39,19 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) {
return s * (1.f - s);
}
/* QuickGELU: x * sigmoid(1.702 * x), evaluated in single precision. */
template <typename OType, typename IType>
__device__ inline OType qgelu(const IType val, const Empty& e) {
  const float x = val;
  const float gate = sigmoid<float, float>(1.702f * x, e);
  return x * gate;
}
/* Derivative of QuickGELU with respect to its input, evaluated in single precision.
 *
 * With f(x) = x * sigmoid(a * x) and a = 1.702:
 *   f'(x) = sigmoid(a * x) + a * x * sigmoid'(a * x)
 * dsigmoid returns the sigmoid derivative with respect to its own argument, so the
 * chain-rule factor `a` has to be applied explicitly here.
 */
template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) {
  const float cval = val;
  const float scaled = 1.702f * cval;
  return scaled * dsigmoid<float, float>(scaled, e) +
         sigmoid<float, float>(scaled, e);
}
template <typename OType, typename IType>
__device__ inline OType swish(const IType val, const Empty& e) {
const float cval = val;
......
......@@ -8,7 +8,7 @@ import torch
import transformer_engine_extensions as tex
__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu']
__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu']
def gelu(
......@@ -140,3 +140,29 @@ def swiglu(
fp8_tensor,
otype,
)
def qgelu(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """QuickGELU with FP8 output.

    When ``fp8_meta_tensor`` is ``None`` the scale, amax history and inverse
    scale are all passed to the op as one shared empty tensor.
    """
    if fp8_meta_tensor is None:
        placeholder = torch.Tensor()
        scale = amax_history = scale_inv = placeholder
    else:
        scale = fp8_meta_tensor.scale
        amax_history = fp8_meta_tensor.amax_history
        scale_inv = fp8_meta_tensor.scale_inv

    return torch.ops.tex_ts.qgelu_ts(
        inp, scale, amax_history, scale_inv, fp8_tensor, otype
    )
......@@ -302,6 +302,13 @@ at::Tensor swiglu(at::Tensor input,
transformer_engine::DType otype
);
/* QuickGELU forward. Returns a newly allocated tensor of dtype `otype`;
 * `scale`, `amax` and `scale_inv` are attached to that output as its
 * scaling metadata. */
at::Tensor qgelu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor dgelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
......@@ -327,6 +334,11 @@ at::Tensor dswiglu(at::Tensor grad,
transformer_engine::DType otype
);
/* QuickGELU backward. Takes the incoming gradient and the forward-pass input
 * and returns a newly allocated tensor of dtype `otype`. */
at::Tensor dqgelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
/***************************************************************************************************
* LayerNorm
**************************************************************************************************/
......
......@@ -265,3 +265,55 @@ at::Tensor dswiglu(at::Tensor grad,
return output;
}
/* QuickGELU forward for PyTorch.
 *
 * The input is viewed as a 2D [rows, cols] problem, where cols is the size of
 * the last dimension and rows is the remaining element count. The result is a
 * freshly allocated [rows, cols] tensor of dtype `otype`, with `amax`, `scale`
 * and `scale_inv` attached as its scaling metadata.
 */
at::Tensor qgelu(at::Tensor input,
                 at::Tensor scale,
                 at::Tensor amax,
                 at::Tensor scale_inv,
                 transformer_engine::DType otype
) {
  using namespace transformer_engine;

  const size_t cols = static_cast<size_t>(input.size(-1));
  const size_t rows = input.numel() / cols;

  auto result = allocateTorchTensor(rows, cols, otype);

  const auto in_dtype = GetTransformerEngineDType(input.scalar_type());
  auto te_input = makeTransformerEngineTensor(input.data_ptr(), {rows, cols}, in_dtype);
  auto te_output = makeTransformerEngineTensor(result.data_ptr(), {rows, cols}, otype,
                                               amax.data_ptr(), scale.data_ptr(),
                                               scale_inv.data_ptr());

  nvte_qgelu(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());

  return result;
}
/* QuickGELU backward for PyTorch.
 *
 * Both `grad` and `input` are viewed as [rows, cols], where cols is the size of
 * the last dimension of `input` and rows is its remaining element count. The
 * result is a freshly allocated [rows, cols] tensor of dtype `otype`.
 */
at::Tensor dqgelu(at::Tensor grad,
                  at::Tensor input,
                  transformer_engine::DType otype
) {
  using namespace transformer_engine;

  const size_t cols = static_cast<size_t>(input.size(-1));
  const size_t rows = input.numel() / cols;

  auto result = allocateTorchTensor(rows, cols, otype);

  const auto in_dtype = GetTransformerEngineDType(input.scalar_type());
  const auto grad_dtype = GetTransformerEngineDType(grad.scalar_type());
  auto te_input = makeTransformerEngineTensor(input.data_ptr(), {rows, cols}, in_dtype);
  auto te_grad = makeTransformerEngineTensor(grad.data_ptr(), {rows, cols}, grad_dtype);
  auto te_output = makeTransformerEngineTensor(result.data_ptr(), {rows, cols}, otype);

  nvte_dqgelu(te_grad.data(), te_input.data(), te_output.data(),
              at::cuda::getCurrentCUDAStream());

  return result;
}
......@@ -72,11 +72,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("geglu", &geglu, "GeGLU with FP8 output");
m.def("reglu", &reglu, "ReGLU with FP8 output");
m.def("swiglu", &swiglu, "SwiGLU with FP8 output");
m.def("qgelu", &qgelu, "QuickGELU with FP8 output");
m.def("dgelu", &dgelu, "Backward of GeLU");
m.def("drelu", &drelu, "Backward of ReLU");
m.def("dgeglu", &dgeglu, "Backward of GeGLU");
m.def("dreglu", &dreglu, "Backward of ReGLU");
m.def("dswiglu", &dswiglu, "Backward of SwiGLU");
m.def("dqgelu", &dqgelu, "Backward of QuickGELU");
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
......
......@@ -247,6 +247,40 @@ at::Tensor swiglu_ts(at::Tensor input,
return output;
}
/* TorchScript-facing wrapper around qgelu.
 *
 * Converts the integer `otype` to a transformer_engine::DType and, for each
 * non-empty metadata tensor, selects the entry for `fp8_tensor` (amax is taken
 * from row 0). Empty metadata tensors are passed through unchanged.
 */
at::Tensor qgelu_ts(at::Tensor input,
                    at::Tensor scale,
                    at::Tensor amax,
                    at::Tensor scale_inv,
                    int64_t fp8_tensor,
                    int64_t otype) {
  const transformer_engine::DType out_dtype = reverse_map_dtype(otype);

  at::Tensor scale_entry = (scale.numel() != 0) ? scale[fp8_tensor] : scale;
  at::Tensor amax_entry = (amax.numel() != 0) ? amax[0][fp8_tensor] : amax;
  at::Tensor scale_inv_entry = (scale_inv.numel() != 0) ? scale_inv[fp8_tensor] : scale_inv;

  return qgelu(input, scale_entry, amax_entry, scale_inv_entry, out_dtype);
}
at::Tensor te_gemm_ts(at::Tensor A,
at::Tensor A_scale_inverse,
......@@ -406,6 +440,7 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("geglu_ts", &geglu_ts);
m.def("reglu_ts", &reglu_ts);
m.def("swiglu_ts", &swiglu_ts);
m.def("qgelu_ts", &qgelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
......
......@@ -63,6 +63,7 @@ def _act_func(activation: str):
'geglu': (tex.geglu, tex.dgeglu),
'reglu': (tex.reglu, tex.dreglu),
'swiglu': (tex.swiglu, tex.dswiglu),
'qgelu': (tex.qgelu, tex.dqgelu)
}
if activation not in funcs:
raise NotImplementedError("Activation type " + activation + " is not supported!")
......@@ -1078,7 +1079,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
type of normalization applied.
activation : str, default = 'gelu'
activation function used.
Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'.
Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu'.
init_method : Callable, default = `None`
used for initializing FC1 weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
......
......@@ -163,7 +163,7 @@ class TransformerLayer(torch.nn.Module):
if set to `False`, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu'
Type of activation used in MLP block.
Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'.
Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu' and 'qgelu'.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment