Unverified commit 53a3bc35, authored by Phuong Nguyen, committed by GitHub

[Pytorch] Added squared ReLU implementation (#846)



* added squared relu in te-torch
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 818c5318
@@ -68,7 +68,7 @@ batch_sizes = [1, 2]
all_boolean = [True, False]
-all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu"]
+all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
@@ -333,12 +333,17 @@ class TorchQuickGELU(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input * torch.sigmoid(1.702 * input)

+class TorchSquaredRELU(nn.Module):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return (input > 0) * input * input

_supported_act = {'geglu' : nn.GELU(approximate="tanh"),
                  'gelu' : nn.GELU(approximate="tanh"),
                  'reglu' : nn.ReLU(),
                  'relu' : nn.ReLU(),
                  'swiglu' : nn.SiLU(),
-                 'qgelu' : TorchQuickGELU()}
+                 'qgelu' : TorchQuickGELU(),
+                 'srelu' : TorchSquaredRELU()}

class TorchGLU(nn.Module):
...
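For reference, the test helper above matches the usual definition of squared ReLU, relu(x)². A minimal standalone check (hypothetical, plain PyTorch):

import torch

x = torch.randn(4)
ref = (x > 0) * x * x                      # same expression as TorchSquaredRELU
assert torch.allclose(ref, torch.relu(x) ** 2)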
@@ -113,7 +113,7 @@ if is_bf16_compatible():  # bf16 requires sm_80 or higher
all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]
-all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
+all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]

def _disable_wgrads(block):
...
@@ -8,7 +8,7 @@ import torch
import transformer_engine_extensions as tex

-__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu']
+__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu', 'srelu']

def gelu(
@@ -166,3 +166,28 @@ def qgelu(
        fp8_tensor,
        otype,
    )

+def srelu(
+    inp: torch.Tensor,
+    fp8_meta_tensor: tex.FP8TensorMeta,
+    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
+    otype: tex.DType,
+) -> torch.Tensor:
+    """Squared ReLU with FP8 output"""
+    empty_tensor = torch.Tensor()
+    if fp8_meta_tensor is not None:
+        scale = fp8_meta_tensor.scale
+        amax_history = fp8_meta_tensor.amax_history
+        scale_inv = fp8_meta_tensor.scale_inv
+    else:
+        scale = empty_tensor
+        amax_history = empty_tensor
+        scale_inv = empty_tensor
+    return torch.ops.tex_ts.srelu_ts(
+        inp,
+        scale,
+        amax_history,
+        scale_inv,
+        fp8_tensor,
+        otype,
+    )
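A rough sketch of how this wrapper might be invoked without FP8 quantization (hypothetical usage; the enum values tex.FP8FwdTensors.GEMM1_INPUT and tex.DType.kBFloat16 are assumed to exist as they do for the sibling activations):

import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import srelu

x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
# Passing None routes through the empty-tensor branch above, so no FP8
# scaling metadata is used; the output dtype is still chosen via otype.
y = srelu(x, None, tex.FP8FwdTensors.GEMM1_INPUT, tex.DType.kBFloat16)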
@@ -340,6 +340,13 @@ at::Tensor qgelu(at::Tensor input,
                 transformer_engine::DType otype
);

+at::Tensor srelu(at::Tensor input,
+                 at::Tensor scale,
+                 at::Tensor amax,
+                 at::Tensor scale_inv,
+                 transformer_engine::DType otype
+);

at::Tensor dgelu(at::Tensor grad,
                 at::Tensor input,
                 transformer_engine::DType otype
@@ -370,6 +377,11 @@ at::Tensor dqgelu(at::Tensor grad,
                  transformer_engine::DType otype
);

+at::Tensor dsrelu(at::Tensor grad,
+                  at::Tensor input,
+                  transformer_engine::DType otype
+);

/***************************************************************************************************
 * LayerNorm
 **************************************************************************************************/
...
@@ -317,3 +317,56 @@ at::Tensor dqgelu(at::Tensor grad,
  return output;
}

+at::Tensor srelu(at::Tensor input,
+                 at::Tensor scale,
+                 at::Tensor amax,
+                 at::Tensor scale_inv,
+                 transformer_engine::DType otype
+) {
+  using namespace transformer_engine;
+
+  size_t N = static_cast<size_t>(input.size(-1));
+  size_t M = static_cast<size_t>(input.numel()) / N;
+
+  auto output = allocateTorchTensor(M,
+                                    N,
+                                    otype);
+
+  auto itype = GetTransformerEngineDType(input.scalar_type());
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
+  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype,
+                                               amax.data_ptr(), scale.data_ptr(),
+                                               scale_inv.data_ptr());
+
+  nvte_srelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
+
+  return output;
+}
+
+at::Tensor dsrelu(at::Tensor grad,
+                  at::Tensor input,
+                  transformer_engine::DType otype
+) {
+  using namespace transformer_engine;
+
+  size_t N = static_cast<size_t>(input.size(-1));
+  size_t M = input.numel() / N;
+
+  auto output = allocateTorchTensor(M,
+                                    N,
+                                    otype);
+
+  auto itype = GetTransformerEngineDType(input.scalar_type());
+  auto gtype = GetTransformerEngineDType(grad.scalar_type());
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
+  auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype);
+  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
+
+  nvte_dsrelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
+
+  return output;
+}
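The backward pair computes the derivative of the squared ReLU, which is 2x for x > 0 and 0 elsewhere. A quick autograd sanity check of those semantics in plain PyTorch (a standalone sketch, not the fused nvte kernel):

import torch

x = torch.randn(8, requires_grad=True)
(torch.relu(x) ** 2).sum().backward()
# Expected gradient: 2x on the positive side, zero elsewhere.
expected = torch.where(x > 0, 2 * x, torch.zeros_like(x)).detach()
assert torch.allclose(x.grad, expected)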
@@ -78,12 +78,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("reglu", &reglu, "ReGLU with FP8 output");
  m.def("swiglu", &swiglu, "SwiGLU with FP8 output");
  m.def("qgelu", &qgelu, "QuickGELU with FP8 output");
+  m.def("srelu", &srelu, "Squared ReLU with FP8 output");
  m.def("dgelu", &dgelu, "Backward of GeLU");
  m.def("drelu", &drelu, "Backward of ReLU");
  m.def("dgeglu", &dgeglu, "Backward of GeGLU");
  m.def("dreglu", &dreglu, "Backward of ReGLU");
  m.def("dswiglu", &dswiglu, "Backward of SwiGLU");
  m.def("dqgelu", &dqgelu, "Backward of QuickGELU");
+  m.def("dsrelu", &dsrelu, "Backward of Squared ReLU");
  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
  m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
...
@@ -285,6 +285,41 @@ at::Tensor qgelu_ts(at::Tensor input,
  return output;
}

+at::Tensor srelu_ts(at::Tensor input,
+                    at::Tensor scale,
+                    at::Tensor amax,
+                    at::Tensor scale_inv,
+                    int64_t fp8_tensor,
+                    int64_t otype) {
+  transformer_engine::DType otype_arg = reverse_map_dtype(otype);
+
+  at::Tensor s, a, s_inv;
+  if (scale.numel()) {
+    s = scale[fp8_tensor];
+  } else {
+    s = scale;
+  }
+
+  if (amax.numel()) {
+    a = amax[0][fp8_tensor];
+  } else {
+    a = amax;
+  }
+
+  if (scale_inv.numel()) {
+    s_inv = scale_inv[fp8_tensor];
+  } else {
+    s_inv = scale_inv;
+  }
+
+  at::Tensor output = srelu(input,
+                            s,
+                            a,
+                            s_inv,
+                            otype_arg);
+  return output;
+}
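The indexing above follows the usual Transformer Engine layout of the FP8 meta tensors: scale and scale_inv hold one entry per FP8 tensor, while amax_history is (history_length, num_tensors), so row 0 holds the current step's amax. A hypothetical illustration of the shapes being sliced:

import torch

num_tensors, history_len = 4, 1024         # hypothetical sizes
scale = torch.ones(num_tensors)
amax_history = torch.zeros(history_len, num_tensors)
scale_inv = torch.ones(num_tensors)

fp8_tensor = 0                             # index of this op's output tensor
s = scale[fp8_tensor]                      # per-tensor scale
a = amax_history[0][fp8_tensor]            # current-step amax slot
s_inv = scale_inv[fp8_tensor]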
at::Tensor te_gemm_ts(at::Tensor A,
                      at::Tensor A_scale_inverse,
                      int64_t A_fp8_tensor,
@@ -458,6 +493,7 @@ TORCH_LIBRARY(tex_ts, m) {
  m.def("reglu_ts", &reglu_ts);
  m.def("swiglu_ts", &swiglu_ts);
  m.def("qgelu_ts", &qgelu_ts);
+  m.def("srelu_ts", &srelu_ts);
  m.def("te_gemm_ts", &te_gemm_ts);
  m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
  m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
...
@@ -64,7 +64,8 @@ def _act_func(activation: str):
        'geglu': (tex.geglu, tex.dgeglu),
        'reglu': (tex.reglu, tex.dreglu),
        'swiglu': (tex.swiglu, tex.dswiglu),
-       'qgelu': (tex.qgelu, tex.dqgelu)
+       'qgelu': (tex.qgelu, tex.dqgelu),
+       'srelu': (tex.srelu, tex.dsrelu),
    }
    if activation not in funcs:
        raise NotImplementedError("Activation type " + activation + " is not supported!")
@@ -1194,7 +1195,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 type of normalization applied.
    activation : str, default = 'gelu'
                 activation function used.
-                Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu'.
+                Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu', 'srelu'.
    init_method : Callable, default = `None`
                 used for initializing FC1 weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
...
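With the mapping in _act_func extended, the new activation should be selectable like any other, e.g. (hypothetical usage sketch):

import torch
from transformer_engine.pytorch import LayerNormMLP

mlp = LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096,
                   activation="srelu").cuda()
y = mlp(torch.randn(128, 2, 1024, device="cuda"))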
@@ -163,7 +163,7 @@ class TransformerLayer(torch.nn.Module):
                 if set to `False`, the transformer layer will not learn any additive biases.
    activation : str, default = 'gelu'
                 Type of activation used in MLP block.
-                Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu' and 'qgelu'.
+                Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'.
    device : Union[torch.device, str], default = "cuda"
                 The device on which the parameters of the model will allocated. It is the user's
                 responsibility to ensure all parameters are moved to the GPU before running the
...
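And at the transformer-block level (again a hypothetical sketch; the constructor arguments follow the documented signature):

import torch
from transformer_engine.pytorch import TransformerLayer

layer = TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                         num_attention_heads=16, activation="srelu").cuda()
x = torch.randn(128, 2, 1024, device="cuda")   # (seq, batch, hidden)
y = layer(x)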