Unverified commit c67bb2fc authored by Przemyslaw Tredak, committed by GitHub
Browse files

Adding other activation types to LayerNormMLP (#265)



* Added ReLU and GLU variants to common
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* PyTorch changes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* PyTorch C++ lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix storage errors
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Compute bgrad
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix numerical tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX export tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent df6f347f
...@@ -167,14 +167,73 @@ at::Tensor fp8_transpose(at::Tensor input, ...@@ -167,14 +167,73 @@ at::Tensor fp8_transpose(at::Tensor input,
transformer_engine::DType otype transformer_engine::DType otype
); );
/***************************************************************************************************
* Activations
**************************************************************************************************/
at::Tensor fp8_gelu(at::Tensor input, at::Tensor gelu(at::Tensor input,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype transformer_engine::DType otype
); );
// Forward activations. Each applies its activation to `input` and returns the
// result cast to `otype`, using the FP8 quantization meta tensors
// (scale, amax, scale_inv). The GLU variants (geglu/reglu/swiglu) presumably
// gate one half of the input with the other half — confirm against the
// common-library implementation.
at::Tensor relu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor geglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor reglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor swiglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
// Backward counterparts: given the incoming gradient `grad` and the forward
// `input`, return the activation gradient in `otype`. Note these take no FP8
// scaling arguments, unlike the forward functions above.
at::Tensor dgelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
at::Tensor drelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
at::Tensor dgeglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
at::Tensor dreglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
at::Tensor dswiglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
/***************************************************************************************************
* LayerNorm
**************************************************************************************************/
std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
const at::Tensor &x, const at::Tensor &x,
......
...@@ -47,17 +47,177 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input, ...@@ -47,17 +47,177 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input,
return output; return output;
} }
at::Tensor fp8_gelu_ts(at::Tensor input, at::Tensor gelu_ts(at::Tensor input,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
at::Tensor scale_inv, at::Tensor scale_inv,
int64_t fp8_tensor, int64_t fp8_tensor,
int64_t otype) { int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype); transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor output = fp8_gelu(input,
scale[fp8_tensor], at::Tensor s, a, s_inv;
amax[0][fp8_tensor], if (scale.numel()) {
scale_inv[fp8_tensor], s = scale[fp8_tensor];
} else {
s = scale;
}
if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}
if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}
at::Tensor output = gelu(input,
s,
a,
s_inv,
otype_arg);
return output;
}
// TorchScript wrapper around relu(). When a quantization meta tensor is
// non-empty, the entry for this FP8 tensor is selected
// (scale[fp8_tensor], amax[0][fp8_tensor], scale_inv[fp8_tensor]);
// empty meta tensors are forwarded unchanged.
at::Tensor relu_ts(at::Tensor input,
                   at::Tensor scale,
                   at::Tensor amax,
                   at::Tensor scale_inv,
                   int64_t fp8_tensor,
                   int64_t otype) {
  const transformer_engine::DType otype_arg = reverse_map_dtype(otype);
  const at::Tensor s = scale.numel() ? scale[fp8_tensor] : scale;
  const at::Tensor a = amax.numel() ? amax[0][fp8_tensor] : amax;
  const at::Tensor s_inv = scale_inv.numel() ? scale_inv[fp8_tensor] : scale_inv;
  return relu(input, s, a, s_inv, otype_arg);
}
// TorchScript wrapper around reglu(). Non-empty meta tensors are indexed by
// fp8_tensor (scale[fp8_tensor], amax[0][fp8_tensor], scale_inv[fp8_tensor]);
// empty ones pass through as-is.
at::Tensor reglu_ts(at::Tensor input,
                    at::Tensor scale,
                    at::Tensor amax,
                    at::Tensor scale_inv,
                    int64_t fp8_tensor,
                    int64_t otype) {
  const transformer_engine::DType otype_arg = reverse_map_dtype(otype);
  const at::Tensor s = scale.numel() ? scale[fp8_tensor] : scale;
  const at::Tensor a = amax.numel() ? amax[0][fp8_tensor] : amax;
  const at::Tensor s_inv = scale_inv.numel() ? scale_inv[fp8_tensor] : scale_inv;
  return reglu(input, s, a, s_inv, otype_arg);
}
// TorchScript wrapper around geglu(). Non-empty meta tensors are indexed by
// fp8_tensor (scale[fp8_tensor], amax[0][fp8_tensor], scale_inv[fp8_tensor]);
// empty ones pass through as-is.
at::Tensor geglu_ts(at::Tensor input,
                    at::Tensor scale,
                    at::Tensor amax,
                    at::Tensor scale_inv,
                    int64_t fp8_tensor,
                    int64_t otype) {
  const transformer_engine::DType otype_arg = reverse_map_dtype(otype);
  const at::Tensor s = scale.numel() ? scale[fp8_tensor] : scale;
  const at::Tensor a = amax.numel() ? amax[0][fp8_tensor] : amax;
  const at::Tensor s_inv = scale_inv.numel() ? scale_inv[fp8_tensor] : scale_inv;
  return geglu(input, s, a, s_inv, otype_arg);
}
at::Tensor swiglu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor s, a, s_inv;
if (scale.numel()) {
s = scale[fp8_tensor];
} else {
s = scale;
}
if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}
if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}
at::Tensor output = swiglu(input,
s,
a,
s_inv,
otype_arg); otype_arg);
return output; return output;
} }
...@@ -171,7 +331,11 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, ...@@ -171,7 +331,11 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
TORCH_LIBRARY(tex_ts, m) { TORCH_LIBRARY(tex_ts, m) {
m.def("cast_to_fp8_ts", &cast_to_fp8_ts); m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
m.def("cast_from_fp8_ts", &cast_from_fp8_ts); m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
m.def("fp8_gelu_ts", &fp8_gelu_ts); m.def("gelu_ts", &gelu_ts);
m.def("relu_ts", &relu_ts);
m.def("geglu_ts", &geglu_ts);
m.def("reglu_ts", &reglu_ts);
m.def("swiglu_ts", &swiglu_ts);
m.def("te_gemm_ts", &te_gemm_ts); m.def("te_gemm_ts", &te_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
......
...@@ -254,7 +254,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma): ...@@ -254,7 +254,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER) register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER) register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
register_custom_op_symbolic('tex_ts::fp8_gelu_ts', onnx_fp8_gelu, VER) register_custom_op_symbolic('tex_ts::gelu_ts', onnx_fp8_gelu, VER)
register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER) register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER) register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER) register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER)
...@@ -136,6 +136,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -136,6 +136,9 @@ class TransformerLayer(torch.nn.Module):
using :attr:`fuse_qkv_params=False`. using :attr:`fuse_qkv_params=False`.
bias : bool, default = `True` bias : bool, default = `True`
if set to `False`, the transformer layer will not learn any additive biases. if set to `False`, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu'
Type of activation used in MLP block.
Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -214,6 +217,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -214,6 +217,7 @@ class TransformerLayer(torch.nn.Module):
qkv_weight_interleaved: bool = True, qkv_weight_interleaved: bool = True,
ub_tp_comm_overlap: bool = False, ub_tp_comm_overlap: bool = False,
bias: bool = True, bias: bool = True,
activation: str = 'gelu'
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -316,9 +320,11 @@ class TransformerLayer(torch.nn.Module): ...@@ -316,9 +320,11 @@ class TransformerLayer(torch.nn.Module):
bias=bias, bias=bias,
) )
# LayerNorm -> gelu(Linear + Bias) -> Linear # LayerNorm -> activation(Linear + Bias) -> Linear
# parallel_mode not supported for LayerNormMLP, # parallel_mode not supported for LayerNormMLP,
# FC1 is CPL and FC2 is RPL # FC1 is CPL and FC2 is RPL
# In the case of GLU activation, FC1 handles both
# Linear layers before the activation
self.layernorm_mlp = LayerNormMLP( self.layernorm_mlp = LayerNormMLP(
hidden_size, hidden_size,
ffn_hidden_size, ffn_hidden_size,
...@@ -342,6 +348,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -342,6 +348,7 @@ class TransformerLayer(torch.nn.Module):
ub_bulk_dgrad=ub_bulk_dgrad, ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_rs=ub_split_rs, ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag, ub_split_ag=ub_split_ag,
activation=activation,
) )
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment