Unverified Commit c67bb2fc authored by Przemyslaw Tredak, committed by GitHub

Adding other activation types to LayerNormMLP (#265)



* Added ReLU and GLU variants to common
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* pyTorch changes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* PyTorch C++ lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix storage errors
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Compute bgrad
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix numerical tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX export tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent df6f347f
@@ -167,14 +167,73 @@ at::Tensor fp8_transpose(at::Tensor input,
transformer_engine::DType otype
);
/***************************************************************************************************
* Activations
**************************************************************************************************/
at::Tensor gelu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor relu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor geglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor reglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor swiglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor dgelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
at::Tensor drelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
at::Tensor dgeglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
at::Tensor dreglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
at::Tensor dswiglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
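
For orientation, the GLU-variant kernels declared above (geglu, reglu, swiglu) take an input whose last dimension is twice that of their output: one half is passed through the base activation and multiplied element-wise with the other half. Below is a minimal PyTorch reference sketch of that behaviour, not the CUDA implementation; the exact ordering of the two halves is an assumption based on the standard GLU-variant formulation.

import torch
import torch.nn.functional as F

def geglu_ref(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half, apply GELU to one half, gate with the other.
    a, b = x.chunk(2, dim=-1)
    return F.gelu(a) * b

def reglu_ref(x: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)
    return F.relu(a) * b

def swiglu_ref(x: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)
    return F.silu(a) * b  # SiLU is the swish activation

x = torch.randn(16, 256)     # the last dimension must be even for the GLU variants
print(geglu_ref(x).shape)    # torch.Size([16, 128])
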
/***************************************************************************************************
* LayerNorm
**************************************************************************************************/
std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
const at::Tensor &x,
@@ -47,18 +47,178 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input,
return output;
}
at::Tensor gelu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
  // scale, amax and scale_inv may be empty tensors (e.g. when the output is
  // not FP8); only index into the FP8 metadata when it is actually provided.
  at::Tensor s, a, s_inv;
if (scale.numel()) {
s = scale[fp8_tensor];
} else {
s = scale;
}
if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}
if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}
at::Tensor output = gelu(input,
s,
a,
s_inv,
otype_arg);
return output;
}
at::Tensor relu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor s, a, s_inv;
if (scale.numel()) {
s = scale[fp8_tensor];
} else {
s = scale;
}
if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}
if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}
at::Tensor output = relu(input,
s,
a,
s_inv,
otype_arg);
return output;
}
at::Tensor reglu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor s, a, s_inv;
if (scale.numel()) {
s = scale[fp8_tensor];
} else {
s = scale;
}
if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}
if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}
at::Tensor output = reglu(input,
s,
a,
s_inv,
otype_arg);
return output;
}
at::Tensor geglu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor s, a, s_inv;
if (scale.numel()) {
s = scale[fp8_tensor];
} else {
s = scale;
}
if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}
if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}
at::Tensor output = geglu(input,
s,
a,
s_inv,
otype_arg);
return output;
}
at::Tensor swiglu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor s, a, s_inv;
if (scale.numel()) {
s = scale[fp8_tensor];
} else {
s = scale;
}
if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}
if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}
at::Tensor output = swiglu(input,
s,
a,
s_inv,
otype_arg);
return output;
}
@@ -171,7 +331,11 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
TORCH_LIBRARY(tex_ts, m) {
m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
m.def("fp8_gelu_ts", &fp8_gelu_ts);
m.def("gelu_ts", &gelu_ts);
m.def("relu_ts", &relu_ts);
m.def("geglu_ts", &geglu_ts);
m.def("reglu_ts", &reglu_ts);
m.def("swiglu_ts", &swiglu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
@@ -254,7 +254,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
register_custom_op_symbolic('tex_ts::gelu_ts', onnx_fp8_gelu, VER)
register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER)
@@ -136,6 +136,9 @@ class TransformerLayer(torch.nn.Module):
using :attr:`fuse_qkv_params=False`.
bias : bool, default = `True`
if set to `False`, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu'
Type of activation used in MLP block.
Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'.
Parallelism parameters
----------------------
@@ -214,6 +217,7 @@ qkv_weight_interleaved: bool = True,
qkv_weight_interleaved: bool = True,
ub_tp_comm_overlap: bool = False,
bias: bool = True,
activation: str = 'gelu'
) -> None:
super().__init__()
@@ -316,9 +320,11 @@ bias=bias,
bias=bias,
)
# LayerNorm -> activation(Linear + Bias) -> Linear
# parallel_mode not supported for LayerNormMLP,
# FC1 is CPL and FC2 is RPL
# In the case of GLU activation, FC1 handles both
# Linear layers before the activation
self.layernorm_mlp = LayerNormMLP(
hidden_size,
ffn_hidden_size,
@@ -342,6 +348,7 @@ ub_bulk_dgrad=ub_bulk_dgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
activation=activation,
)
self.hidden_dropout = hidden_dropout
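
From the user's perspective, the new argument is simply forwarded to LayerNormMLP. A minimal usage sketch follows; the concrete sizes are illustrative, and the (sequence, batch, hidden) input layout is assumed to match what TransformerLayer expects.

import torch
from transformer_engine.pytorch import TransformerLayer

layer = TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    activation='swiglu',   # new: 'gelu' (default), 'relu', 'reglu', 'geglu' or 'swiglu'
).cuda()

x = torch.randn(128, 2, 1024, device="cuda")
y = layer(x)
print(y.shape)  # output has the same shape as the input

For the GLU variants, FC1 produces both halves consumed by the activation (per the comment above), so its output dimension is doubled relative to ffn_hidden_size.
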