Unverified commit c473f0e6, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Adding Gated/Non-gated ReLU, Quick GeLU, Squared ReLU (#826)



* renamed gelu to act

* added relu, srelu, qgelu

* fixes initialization for layernorm_fp8_mlp tests

* moved activation_fp8 prim into the unit-test file

* Moved NVTE_Activation_Enum to common/.../activation.h

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 07291027
......@@ -19,7 +19,9 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.mlp import activation_lu, activation_lu_fp8, fused_layernorm_fp8_mlp
from transformer_engine.jax.mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax.cpp_extensions import act_lu_fp8, dact_lu_dbias_cast_transpose
from transformer_engine.jax.cpp_extensions import dgated_act_lu_cast_transpose
GEMM_CASES = [
......@@ -34,10 +36,16 @@ LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
if fn_or_string == 'linear':
return lambda x: x
if fn_or_string == 'quick_gelu':
return lambda x: nn.gelu(x, approximate=True)
if fn_or_string == 'squared_relu':
return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
if callable(fn_or_string):
......@@ -171,14 +179,20 @@ class TestFP8Dot:
assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(128, 256, 512),
@pytest.mark.parametrize('m,n,k', [(256, 128, 512),
(16384, 1024, 2816),
(16384, 2816, 1024),
(16384, 1024, 1024)])
@pytest.mark.parametrize('activation_type', [('gelu', ),
('gelu', 'linear'),
('silu', ),
('silu', 'linear')])
('silu', 'linear'),
('relu',),
('relu', 'linear'),
('quick_gelu',),
('quick_gelu', 'linear'),
('squared_relu',),
('squared_relu', 'linear')])
@pytest.mark.parametrize('use_bias', [True, False])
def test_grad_fused_layernorm_fp8_mlp(self, m, n, k,
activation_type: Sequence[Union[str, Callable]], use_bias: bool):
......@@ -187,8 +201,8 @@ class TestFP8Dot:
subkeys = jax.random.split(key, 6)
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) / jnp.sqrt(k)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
......@@ -345,7 +359,13 @@ class TestActivationLu:
@pytest.mark.parametrize('activation_type', [('gelu',),
('gelu', 'linear'),
('silu',),
('silu', 'linear')])
('silu', 'linear'),
('relu',),
('relu', 'linear'),
('quick_gelu',),
('quick_gelu', 'linear'),
('squared_relu',),
('squared_relu', 'linear') ])
def test_activation_lu(self, random_inputs, activation_type):
x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1)
......@@ -363,37 +383,74 @@ class TestActivationLu:
class TestActivationLuFP8(TestActivationLu):
def primitive_func(self, inputs, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv):
return jnp.mean(
activation_lu_fp8(inputs,
amax, scale, scale_inv,
jnp.float8_e4m3fn, jnp.float8_e5m2,
activation_type = self.activation_type))
def prim_func(self, x):
amax = self.amax
scale = self.scale
scale_inv = self.scale_inv
activation_type = self.activation_type
@jax.custom_vjp
def _prim_func(x, _x_t, _dbias, _amax):
output = _prim_func_fwd(x, _x_t, _dbias, _amax)
return output
def _prim_func_fwd(x, _x_t, _dbias, _amax):
activation_lu_out, _ = act_lu_fp8(x, amax, scale, scale_inv,
FP8Helper.FWD_DTYPE, activation_type)
activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
ctx = (x)
return activation_lu_out, ctx
def _prim_func_bwd(ctx, g):
x = ctx
if len(self.activation_type) > 1: #gated, no bias
dactivation_lu, dactivation_lu_trans, amax_out = \
dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv,
FP8Helper.BWD_DTYPE, -1, activation_type)
dbias = jnp.empty(x.shape[-1], x.dtype)
else: #not gated, with bias
dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE,
-1, -2, self.activation_type)
dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
return ctx
_prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)
dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype)
dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
amax_no_use = jnp.zeros(1, jnp.float32)
value_n_grad_primitive_func = value_and_grad(lambda a, b, c, d:
jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3))
return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize('activation_type', [('gelu',),
('gelu', 'linear'),
('silu',),
('silu', 'linear')])
('silu', 'linear'),
('relu',),
('relu', 'linear'),
('quick_gelu',),
('quick_gelu', 'linear'),
('squared_relu',),
('squared_relu', 'linear') ])
def test_activation_lu(self, random_inputs, activation_type):
self.amax = jnp.zeros(1, jnp.float32)
self.scale = jnp.ones(1, jnp.float32)
self.scale_inv = jnp.ones(1, jnp.float32)
self.activation_type = activation_type
self.transpose_indices = (1, 2, 0)
x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1)
value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0, 1, 2, 3, 4, 5,)))
transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1)
dx_trans_no_use = jnp.zeros([x.shape[i] for i in transpose_indices], dtype=x.dtype)
dbias_no_use = jnp.zeros(x.shape[-1], dtype=x.dtype)
prim_out, (prim_grad, prim_grad_trans, dbias, amax, _, _) = \
value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use,
self.amax, self.scale, self.scale_inv)
prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
ref_out, (ref_grad,) = self.ref_func(x, activation_type)
assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
......@@ -402,7 +459,7 @@ class TestActivationLuFP8(TestActivationLu):
assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(prim_grad_trans,
jnp.transpose(ref_grad, transpose_indices),
jnp.transpose(ref_grad, self.transpose_indices),
dtype=FP8Helper.BWD_DTYPE)
......
......@@ -73,3 +73,26 @@ void nvte_dqgelu(const NVTETensor grad,
reinterpret_cast<Tensor*>(output),
stream);
}
void nvte_qgeglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
void nvte_dqgeglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgeglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
......@@ -52,3 +52,48 @@ void nvte_dreglu(const NVTETensor grad,
reinterpret_cast<Tensor*>(output),
stream);
}
void nvte_srelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_srelu);
using namespace transformer_engine;
act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
void nvte_dsrelu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsrelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
void nvte_sreglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
void nvte_dsreglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsreglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
......@@ -8,23 +8,23 @@
#include "../util/math.h"
void nvte_swish(const NVTETensor input,
void nvte_silu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_swish);
NVTE_API_CALL(nvte_silu);
using namespace transformer_engine;
act_fn<fp32, Empty, swish<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
act_fn<fp32, Empty, silu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
void nvte_dswish(const NVTETensor grad,
void nvte_dsilu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dswish);
NVTE_API_CALL(nvte_dsilu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dswish<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
......@@ -35,7 +35,7 @@ void nvte_swiglu(const NVTETensor input,
cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, swish<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
gated_act_fn<fp32, Empty, silu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
......@@ -46,7 +46,7 @@ void nvte_dswiglu(const NVTETensor grad,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dswiglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, swish<fp32, fp32>, dswish<fp32, fp32>>(
dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
......
......@@ -17,20 +17,52 @@
extern "C" {
#endif
/*! \brief Compute GELU activation of the input.
/* Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */
/*! \brief Compute activation of the input.
*
* \param[in] input Input tensor for GELU activation.
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
enum class NVTE_Activation_Type {
GELU,
GEGLU,
SILU,
SWIGLU,
RELU,
REGLU,
QGELU,
QGEGLU,
SRELU,
SREGLU,
};
void nvte_gelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute GELU activation gradient.
void nvte_silu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
void nvte_relu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
void nvte_qgelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
void nvte_srelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute activation gradient.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for GELU activation.
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
......@@ -39,135 +71,81 @@ void nvte_dgelu(const NVTETensor grad,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute GeGLU of the input.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes GELU(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_geglu(const NVTETensor input,
void nvte_dsilu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute GeGLU gradient.
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dgeglu(const NVTETensor grad,
void nvte_drelu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
void nvte_dqgelu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute SiLU activation of the input.
*
* \param[in] input Input tensor for GELU activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_swish(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute Swish activation gradient.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for Swish activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dswish(const NVTETensor grad,
void nvte_dsrelu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute SwiGLU activation of the input.
/*! \brief Compute gated activation of the input.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Swish(input[N, :H]) x input[N, H:]
* It computes Act(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_geglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
void nvte_swiglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute SwiGLU gradient.
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dswiglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute RELU activation of the input.
*
* \param[in] input Input tensor for RELU activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_relu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
void nvte_reglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute RELU activation gradient.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for RELU activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_drelu(const NVTETensor grad,
const NVTETensor input,
void nvte_qgeglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute ReGLU activation of the input.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes ReLU(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_reglu(const NVTETensor input,
void nvte_sreglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute ReGLU gradient.
/*! \brief Compute gated activation gradient.
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dgeglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
void nvte_dswiglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
void nvte_dreglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute QuickGELU activation of the input.
*
* \param[in] input Input tensor for QuickGELU activation.
* \param[in,out] output Output tensor. Approximates GELU as input x sigmoid(1.702 x input).
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_qgelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
void nvte_dqgeglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute QuickGELU activation gradient.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for QuickGELU activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dqgelu(const NVTETensor grad,
void nvte_dsreglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
......
......@@ -90,35 +90,6 @@ void nvte_fp8_transpose_dbias(const NVTETensor input,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute backward of GELU operation on the input, then cast and transpose. Additionally,
* reduce the result of the GELU backward along the first dimension.
*
* This function produces 3 results:
* - `cast_output` is equal to `cast(dGELU(input))`
* - `transposed_output` is equal to `transpose(cast(dGELU(input)))`
* - `dbias` is equal to `reduce(dGELU(input), axis=0)`
*
* Calling this function with workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gelu_input Tensor used as input to the forward of GELU operation.
* Shape [N, H].
* \param[in,out] cast_output Result of the cast. Shape: [N, H].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[out] dbias Result of the reduction of the dGELU(input) along the
* first dimension. Shape: [H].
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
const NVTETensor gelu_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Cast and transpose multiple tensors.
*
* This function casts each input tensor and produces 2 results:
......@@ -140,38 +111,19 @@ void nvte_multi_cast_transpose(size_t num_tensors,
NVTETensor* transposed_output_list,
cudaStream_t stream);
/*! \brief Compute dgeglu of the input, additionally does cast and transpose the dgeglu output.
*
* This function produces 2 results:
* - `cast_output` is the result of the cast
* - `transposed_output` is the transposed result of the cast.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] geglu_input Tensor used as input to the forward of GeGLU operation.
* Shape [N, H * 2].
* \param[in,out] cast_output Result of the cast. Shape: [N, H * 2].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dgeglu_cast_transpose(const NVTETensor input,
const NVTETensor geglu_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
/*! \brief Compute backward of SiLU operation on the input, then cast and transpose. Additionally,
/*! \brief Compute backward of ActLU operation on the input, then cast and transpose. Additionally,
* reduce the result of the SiLU backward along the first dimension.
*
* This function produces 3 results:
* - `cast_output` is equal to `cast(dSiLU(input))`
* - `transposed_output` is equal to `transpose(cast(dSiLU(input)))`
* - `dbias` is equal to `reduce(dSiLU(input), axis=0)`
* - `cast_output` is equal to `cast(dact(input))`
* - `transposed_output` is equal to `transpose(cast(dact(input)))`
* - `dbias` is equal to `reduce(dact(input), axis=0)`
*
* Calling this function with workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] swish_input Tensor used as input to the forward of SiLU operation.
* \param[in] act_input Tensor used as input to the forward of SiLU operation.
* Shape [N, H].
* \param[in,out] cast_output Result of the cast. Shape: [N, H].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
......@@ -179,33 +131,97 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input,
* first dimension. Shape: [H].
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU
*/
void nvte_cast_transpose_dbias_dswish(const NVTETensor input,
const NVTETensor swish_input,
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute dswiglu of the input, additionally does cast and transpose the dswiglu output.
void nvte_cast_transpose_dbias_dsilu(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
void nvte_cast_transpose_dbias_drelu(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute dgeglu of the input, additionally does cast and transpose the dgeglu output.
*
* This function produces 2 results:
* - `cast_output` is the result of the cast
* - `transposed_output` is the transposed result of the cast.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] swiglu_input Tensor used as input to the forward of SwiGLU operation.
* \param[in] gated_act_input Tensor used as input to the forward of GeGLU operation.
* Shape [N, H * 2].
* \param[in,out] cast_output Result of the cast. Shape: [N, H * 2].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N].
* \param[in] stream CUDA stream used for the operation.
*/
Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU
*/
void nvte_dgeglu_cast_transpose(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
void nvte_dswiglu_cast_transpose(const NVTETensor input,
const NVTETensor swiglu_input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
void nvte_dreglu_cast_transpose(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
void nvte_dqgeglu_cast_transpose(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
void nvte_dsreglu_cast_transpose(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -124,7 +124,7 @@ struct CTDBiasDGeluParam {
using OutputType = OType;
using ComputeType = CType;
const IType *input;
const IType2 *gelu_input;
const IType2 *act_input;
OType *output_c;
OType *output_t;
const CType *scale_ptr;
......@@ -627,7 +627,7 @@ template <typename ComputeType, typename ParamOP,
ComputeType (*OP)(ComputeType, const ParamOP&)>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_dbias_dgelu_kernel(const Param param,
cast_transpose_dbias_dact_kernel(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
......@@ -655,7 +655,7 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
const IType2 * const my_gelu_input_tile = param.gelu_input +
const IType2 * const my_act_input_tile = param.act_input +
(tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
......@@ -676,7 +676,7 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out];
IVec2 gelu_in[2][nvec_out];
IVec2 act_in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
......@@ -696,7 +696,7 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
gelu_in[0][i].load_from(my_gelu_input_tile, current_stride + my_place + stride * i);
act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i);
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
......@@ -708,22 +708,22 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
for (unsigned int j = 0; j < nvec_out; ++j) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
gelu_in[current_in][j].load_from(my_gelu_input_tile,
act_in[current_in][j].load_from(my_act_input_tile,
current_stride + my_place_in +
stride * (nvec_out + j));
}
}
CVec after_dgelu[nvec_out]; // NOLINT(*)
CVec after_dact[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dgelu[j].data.elt[k] = OP(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
after_dact[j].data.elt[k] = OP(act_in[current_in ^ 1][j].data.elt[k], {}) *
CType(in[current_in ^ 1][j].data.elt[k]);
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
cast_and_transpose_regs_partial_dbias<true>(after_dgelu, out_trans,
cast_and_transpose_regs_partial_dbias<true>(after_dact, out_trans,
partial_dbias, my_output_c_tile,
current_place, stride, max, scale,
(my_id_in_warp + i +
......@@ -789,7 +789,7 @@ template <typename ComputeType, typename ParamOP,
ComputeType (*OP)(ComputeType, const ParamOP&)>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
cast_transpose_dbias_dact_kernel_notaligned(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
......@@ -817,7 +817,7 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
const IType2 * const my_gelu_input_tile = param.gelu_input +
const IType2 * const my_act_input_tile = param.act_input +
(tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
......@@ -847,7 +847,7 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out];
IVec2 gelu_in[2][nvec_out];
IVec2 act_in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
......@@ -869,10 +869,10 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
gelu_in[0][i].load_from(my_gelu_input_tile, current_stride + my_place + stride * i);
act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i);
} else {
in[0][i].clear();
gelu_in[0][i].clear();
act_in[0][i].clear();
}
}
}
......@@ -889,28 +889,28 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
if (valid_load) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
gelu_in[current_in][j].load_from(my_gelu_input_tile,
act_in[current_in][j].load_from(my_act_input_tile,
current_stride + my_place_in +
stride * (nvec_out + j));
} else {
in[current_in][j].clear();
gelu_in[current_in][j].clear();
act_in[current_in][j].clear();
}
}
}
CVec after_dgelu[nvec_out]; // NOLINT(*)
CVec after_dact[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dgelu[j].data.elt[k] = OP(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
after_dact[j].data.elt[k] = OP(act_in[current_in ^ 1][j].data.elt[k], {}) *
CType(in[current_in ^ 1][j].data.elt[k]);
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
const bool valid_store = my_place < tile_length &&
warp_id_in_tile * n_iterations + i < tile_height;
cast_and_transpose_regs_partial_dbias<false>(after_dgelu, out_trans,
cast_and_transpose_regs_partial_dbias<false>(after_dact, out_trans,
partial_dbias, my_output_c_tile,
current_place, stride, max, scale,
(my_id_in_warp + i +
......@@ -983,8 +983,8 @@ template <int nvec_in, int nvec_out,
CType (*OP2)(CType, const ParamOP&)>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
dgeglu_cast_transpose_kernel(const IType * const input,
const IType * const gelu_input,
dgated_act_cast_transpose_kernel(const IType * const input,
const IType * const act_input,
OType * const output_c,
OType * const output_t,
const CType * const scale_ptr,
......@@ -1011,10 +1011,10 @@ dgeglu_cast_transpose_kernel(const IType * const input,
const IType * const my_input_tile = input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
const IType * const my_gelu_input_tile = gelu_input + (tile_id_x * nvec_in +
const IType * const my_act_input_tile = act_input + (tile_id_x * nvec_in +
tile_id_y * row_length * 2 * nvec_out) *
THREADS_PER_WARP;
const IType * const my_gate_input_tile = gelu_input + (tile_id_x * nvec_in +
const IType * const my_gate_input_tile = act_input + (tile_id_x * nvec_in +
tile_id_y * row_length * 2 * nvec_out) *
THREADS_PER_WARP + row_length;
OType * const my_output_c_tile_0 = output_c + (tile_id_x * nvec_in +
......@@ -1034,7 +1034,7 @@ dgeglu_cast_transpose_kernel(const IType * const input,
(THREADS_PER_WARP + 1);
IVec in[2][nvec_out];
IVec gelu_in[2][nvec_out];
IVec act_in[2][nvec_out];
IVec gate_in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
......@@ -1054,7 +1054,7 @@ dgeglu_cast_transpose_kernel(const IType * const input,
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
gelu_in[0][i].load_from(my_gelu_input_tile, current_stride2 + my_place + stride2 * i);
act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i);
gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i);
}
#pragma unroll
......@@ -1067,27 +1067,27 @@ dgeglu_cast_transpose_kernel(const IType * const input,
for (unsigned int j = 0; j < nvec_out; ++j) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
gelu_in[current_in][j].load_from(my_gelu_input_tile,
act_in[current_in][j].load_from(my_act_input_tile,
current_stride2 + my_place_in + stride2 * (nvec_out + j));
gate_in[current_in][j].load_from(my_gate_input_tile,
current_stride2 + my_place_in + stride2 * (nvec_out + j));
}
}
CVec after_dgelu[nvec_out]; // NOLINT(*)
CVec after_dact[nvec_out]; // NOLINT(*)
CVec after_dgate[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dgelu[j].data.elt[k] = OP1(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) *
CType(in[current_in ^ 1][j].data.elt[k]) *
CType(gate_in[current_in ^ 1][j].data.elt[k]);
after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
OP2(gelu_in[current_in ^ 1][j].data.elt[k], {});
OP2(act_in[current_in ^ 1][j].data.elt[k], {});
}
}
OVec out_trans_0[nvec_in]; // NOLINT(*)
cast_and_transpose_regs<true>(after_dgelu, out_trans_0, my_output_c_tile_0,
cast_and_transpose_regs<true>(after_dact, out_trans_0, my_output_c_tile_0,
current_place, stride2, max, scale, true);
OVec out_trans_1[nvec_in]; // NOLINT(*)
cast_and_transpose_regs<true>(after_dgate, out_trans_1, my_output_c_tile_1,
......@@ -1156,8 +1156,8 @@ template <int nvec_in, int nvec_out,
CType (*OP2)(CType, const ParamOP&)>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
dgeglu_cast_transpose_kernel_notaligned(const IType * const input,
const IType * const gelu_input,
dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
const IType * const act_input,
OType * const output_c,
OType * const output_t,
const CType * const scale_ptr,
......@@ -1185,10 +1185,10 @@ dgeglu_cast_transpose_kernel_notaligned(const IType * const input,
const IType * const my_input_tile = input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
const IType * const my_gelu_input_tile = gelu_input + (tile_id_x * nvec_in +
const IType * const my_act_input_tile = act_input + (tile_id_x * nvec_in +
tile_id_y * row_length * 2 * nvec_out) *
THREADS_PER_WARP;
const IType * const my_gate_input_tile = gelu_input + (tile_id_x * nvec_in +
const IType * const my_gate_input_tile = act_input + (tile_id_x * nvec_in +
tile_id_y * row_length * 2 * nvec_out) *
THREADS_PER_WARP + row_length;
OType * const my_output_c_tile_0 = output_c + (tile_id_x * nvec_in +
......@@ -1218,7 +1218,7 @@ dgeglu_cast_transpose_kernel_notaligned(const IType * const input,
(THREADS_PER_WARP + 1);
IVec in[2][nvec_out];
IVec gelu_in[2][nvec_out];
IVec act_in[2][nvec_out];
IVec gate_in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
......@@ -1239,11 +1239,11 @@ dgeglu_cast_transpose_kernel_notaligned(const IType * const input,
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
gelu_in[0][i].load_from(my_gelu_input_tile, current_stride2 + my_place + stride2 * i);
act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i);
gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i);
} else {
in[0][i].clear();
gelu_in[0][i].clear();
act_in[0][i].clear();
gate_in[0][i].clear();
}
}
......@@ -1262,36 +1262,36 @@ dgeglu_cast_transpose_kernel_notaligned(const IType * const input,
if (valid_load) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
gelu_in[current_in][j].load_from(my_gelu_input_tile,
act_in[current_in][j].load_from(my_act_input_tile,
current_stride2 + my_place_in + stride2 * (nvec_out + j));
gate_in[current_in][j].load_from(my_gate_input_tile,
current_stride2 + my_place_in + stride2 * (nvec_out + j));
} else {
in[current_in][j].clear();
gelu_in[current_in][j].clear();
act_in[current_in][j].clear();
gate_in[current_in][j].clear();
}
}
}
}
CVec after_dgelu[nvec_out]; // NOLINT(*)
CVec after_dact[nvec_out]; // NOLINT(*)
CVec after_dgate[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dgelu[j].data.elt[k] = OP1(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) *
CType(in[current_in ^ 1][j].data.elt[k]) *
CType(gate_in[current_in ^ 1][j].data.elt[k]);
after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
OP2(gelu_in[current_in ^ 1][j].data.elt[k], {});
OP2(act_in[current_in ^ 1][j].data.elt[k], {});
}
}
OVec out_trans_0[nvec_in]; // NOLINT(*)
OVec out_trans_1[nvec_in]; // NOLINT(*)
const bool valid_store = my_place < tile_length &&
warp_id_in_tile * n_iterations + i < tile_height;
cast_and_transpose_regs<false>(after_dgelu, out_trans_0, my_output_c_tile_0,
cast_and_transpose_regs<false>(after_dact, out_trans_0, my_output_c_tile_0,
current_place, stride2, max, scale, valid_store);
cast_and_transpose_regs<false>(after_dgate, out_trans_1, my_output_c_tile_1,
current_place, stride2, max, scale, valid_store);
......@@ -1360,8 +1360,8 @@ dgeglu_cast_transpose_kernel_notaligned(const IType * const input,
template <typename ComputeType, typename ParamOP,
ComputeType (*OP)(ComputeType, const ParamOP&)>
void cast_transpose_dbias_dgelu(const Tensor &input,
const Tensor &gelu_input,
void cast_transpose_dbias_dact(const Tensor &input,
const Tensor &act_input,
Tensor *cast_output,
Tensor *transposed_output,
Tensor *dbias,
......@@ -1389,27 +1389,27 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
NVTE_CHECK(input.data.dtype == gelu_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(input.data.shape == gelu_input.data.shape, "Shapes of both inputs must match.");
NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
using InputType2 = InputType;
/* dgelu fusion kernel uses more registers */
constexpr int desired_load_size_dgelu = 4;
constexpr int desired_store_size_dgelu = 4;
/* dact fusion kernel uses more registers */
constexpr int desired_load_size_dact = 4;
constexpr int desired_store_size_dact = 4;
constexpr int itype_size = sizeof(InputType);
constexpr int otype_size = sizeof(OutputType);
constexpr int nvec_in = desired_load_size_dgelu / itype_size;
constexpr int nvec_out = desired_store_size_dgelu / otype_size;
constexpr int nvec_in = desired_load_size_dact / itype_size;
constexpr int nvec_out = desired_store_size_dact / otype_size;
if (workspace->data.dptr == nullptr) {
populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
return;
}
CheckInputTensor(input, "cast_transpose_dbias_dgelu_input");
CheckInputTensor(gelu_input, "gelu_input");
CheckInputTensor(input, "cast_transpose_dbias_dact_input");
CheckInputTensor(act_input, "act_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
......@@ -1434,7 +1434,7 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
using Param = CTDBiasDGeluParam<InputType, InputType2, OutputType, ComputeType>;
Param param;
param.input = reinterpret_cast<const InputType *>(input.data.dptr);
param.gelu_input = reinterpret_cast<const InputType2 *>(gelu_input.data.dptr);
param.act_input = reinterpret_cast<const InputType2 *>(act_input.data.dptr);
param.output_c = reinterpret_cast<OutputType *>(cast_output->data.dptr);
param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
......@@ -1443,23 +1443,23 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
if (full_tile) {
cudaFuncSetAttribute(
cast_transpose_dbias_dgelu_kernel<ComputeType, Empty,
cast_transpose_dbias_dact_kernel<ComputeType, Empty,
nvec_in, nvec_out, Param, OP>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_dbias_dgelu_kernel<ComputeType, Empty,
cast_transpose_dbias_dact_kernel<ComputeType, Empty,
nvec_in, nvec_out, Param, OP>
<<<n_blocks,
cast_transpose_num_threads,
shared_size_transpose,
stream>>>(param, row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel_notaligned<
cudaFuncSetAttribute(cast_transpose_dbias_dact_kernel_notaligned<
ComputeType, Empty,
nvec_in, nvec_out, Param, OP>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_dbias_dgelu_kernel_notaligned<
cast_transpose_dbias_dact_kernel_notaligned<
ComputeType, Empty,
nvec_in, nvec_out, Param, OP>
<<<n_blocks,
......@@ -1476,32 +1476,32 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
template <typename ComputeType, typename ParamOP,
ComputeType (*OP1)(ComputeType, const ParamOP&),
ComputeType (*OP2)(ComputeType, const ParamOP&)>
void dgeglu_cast_transpose(const Tensor &input,
const Tensor &geglu_input,
void dgated_act_cast_transpose(const Tensor &input,
const Tensor &gated_act_input,
Tensor *cast_output,
Tensor *transposed_output,
cudaStream_t stream) {
CheckInputTensor(input, "dgeglu_cast_transpose_input");
CheckInputTensor(geglu_input, "dgeglu_cast_transpose_geglu_input");
CheckOutputTensor(*cast_output, "dgeglu_cast_transpose_cast_output");
CheckOutputTensor(*transposed_output, "dgeglu_cast_transpose_transposed_output");
CheckInputTensor(input, "dgated_act_cast_transpose_input");
CheckInputTensor(gated_act_input, "dgated_act_cast_transpose_gated_act_input");
CheckOutputTensor(*cast_output, "dgated_act_cast_transpose_cast_output");
CheckOutputTensor(*transposed_output, "dgated_act_cast_transpose_transposed_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(geglu_input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(gated_act_input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2,
"T output must have 2 dimensions.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(geglu_input.data.shape[0] == num_rows, "Wrong dimension of output.");
NVTE_CHECK(geglu_input.data.shape[1] == row_length * 2, "Wrong dimension of output.");
NVTE_CHECK(gated_act_input.data.shape[0] == num_rows, "Wrong dimension of output.");
NVTE_CHECK(gated_act_input.data.shape[1] == row_length * 2, "Wrong dimension of output.");
NVTE_CHECK(cast_output->data.shape[0] == num_rows, "Wrong dimension of output.");
NVTE_CHECK(cast_output->data.shape[1] == row_length * 2, "Wrong dimension of output.");
NVTE_CHECK(transposed_output->data.shape[0] == row_length * 2, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(input.data.dtype == geglu_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(input.data.dtype == gated_act_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
"C and T outputs need to have the same type.");
......@@ -1515,13 +1515,13 @@ void dgeglu_cast_transpose(const Tensor &input,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
using InputType2 = InputType;
/* dgelu fusion kernel uses more registers */
constexpr int desired_load_size_dgelu = 4;
constexpr int desired_store_size_dgelu = 4;
/* dact fusion kernel uses more registers */
constexpr int desired_load_size_dact = 4;
constexpr int desired_store_size_dact = 4;
constexpr int itype_size = sizeof(InputType);
constexpr int otype_size = sizeof(OutputType);
constexpr int nvec_in = desired_load_size_dgelu / itype_size;
constexpr int nvec_out = desired_store_size_dgelu / otype_size;
constexpr int nvec_in = desired_load_size_dact / itype_size;
constexpr int nvec_out = desired_store_size_dact / otype_size;
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
......@@ -1533,13 +1533,13 @@ void dgeglu_cast_transpose(const Tensor &input,
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
if (full_tile) {
cudaFuncSetAttribute(dgeglu_cast_transpose_kernel<
cudaFuncSetAttribute(dgated_act_cast_transpose_kernel<
nvec_in, nvec_out,
ComputeType, InputType, OutputType,
Empty, OP1, OP2>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
dgeglu_cast_transpose_kernel< nvec_in, nvec_out,
dgated_act_cast_transpose_kernel< nvec_in, nvec_out,
ComputeType, InputType, OutputType, Empty, OP1, OP2>
<<<n_blocks,
cast_transpose_num_threads,
......@@ -1547,7 +1547,7 @@ void dgeglu_cast_transpose(const Tensor &input,
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
stream>>>(
reinterpret_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const InputType *>(geglu_input.data.dptr),
reinterpret_cast<const InputType *>(gated_act_input.data.dptr),
reinterpret_cast<OutputType *>(cast_output->data.dptr),
reinterpret_cast<OutputType *>(transposed_output->data.dptr),
reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
......@@ -1555,13 +1555,13 @@ void dgeglu_cast_transpose(const Tensor &input,
reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(dgeglu_cast_transpose_kernel_notaligned<
cudaFuncSetAttribute(dgated_act_cast_transpose_kernel_notaligned<
nvec_in, nvec_out,
ComputeType, InputType, OutputType,
Empty, OP1, OP2>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
dgeglu_cast_transpose_kernel_notaligned<nvec_in, nvec_out,
dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out,
ComputeType, InputType, OutputType, Empty, OP1, OP2>
<<<n_blocks,
cast_transpose_num_threads,
......@@ -1569,7 +1569,7 @@ void dgeglu_cast_transpose(const Tensor &input,
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
stream>>>(
reinterpret_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const InputType *>(geglu_input.data.dptr),
reinterpret_cast<const InputType *>(gated_act_input.data.dptr),
reinterpret_cast<OutputType *>(cast_output->data.dptr),
reinterpret_cast<OutputType *>(transposed_output->data.dptr),
reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
......@@ -1600,7 +1600,7 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
}
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
const NVTETensor gelu_input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
......@@ -1608,9 +1608,9 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu);
using namespace transformer_engine;
cast_transpose_dbias_dgelu<fp32, Empty, dgelu<fp32, fp32>>(
cast_transpose_dbias_dact<fp32, Empty, dgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(gelu_input),
*reinterpret_cast<const Tensor*>(act_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
......@@ -1619,32 +1619,32 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
}
void nvte_dgeglu_cast_transpose(const NVTETensor input,
const NVTETensor geglu_input,
const NVTETensor gated_act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu_cast_transpose);
using namespace transformer_engine;
dgeglu_cast_transpose<fp32, Empty, dgelu<fp32, fp32>, gelu<fp32, fp32>>(
dgated_act_cast_transpose<fp32, Empty, dgelu<fp32, fp32>, gelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(geglu_input),
*reinterpret_cast<const Tensor*>(gated_act_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
stream);
}
void nvte_cast_transpose_dbias_dswish(const NVTETensor input,
const NVTETensor swish_input,
void nvte_cast_transpose_dbias_dsilu(const NVTETensor input,
const NVTETensor silu_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dswish);
NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu);
using namespace transformer_engine;
cast_transpose_dbias_dgelu<fp32, Empty, dswish<fp32, fp32>>(
cast_transpose_dbias_dact<fp32, Empty, dsilu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(swish_input),
*reinterpret_cast<const Tensor*>(silu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
......@@ -1659,10 +1659,112 @@ void nvte_dswiglu_cast_transpose(const NVTETensor input,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dswiglu_cast_transpose);
using namespace transformer_engine;
dgeglu_cast_transpose<fp32, Empty, dswish<fp32, fp32>, swish<fp32, fp32>>(
dgated_act_cast_transpose<fp32, Empty, dsilu<fp32, fp32>, silu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(swiglu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
stream);
}
void nvte_cast_transpose_dbias_drelu(const NVTETensor input,
const NVTETensor relu_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_drelu);
using namespace transformer_engine;
cast_transpose_dbias_dact<fp32, Empty, drelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(relu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(workspace),
stream);
}
void nvte_dreglu_cast_transpose(const NVTETensor input,
const NVTETensor gated_act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dreglu_cast_transpose);
using namespace transformer_engine;
dgated_act_cast_transpose<fp32, Empty, drelu<fp32, fp32>, relu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(gated_act_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
stream);
}
void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input,
const NVTETensor srelu_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu);
using namespace transformer_engine;
cast_transpose_dbias_dact<fp32, Empty, dsrelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(srelu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(workspace),
stream);
}
void nvte_dsreglu_cast_transpose(const NVTETensor input,
const NVTETensor gated_act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsreglu_cast_transpose);
using namespace transformer_engine;
dgated_act_cast_transpose<fp32, Empty, dsrelu<fp32, fp32>, srelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(gated_act_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
stream);
}
void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input,
const NVTETensor qgelu_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu);
using namespace transformer_engine;
cast_transpose_dbias_dact<fp32, Empty, dqgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(qgelu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(workspace),
stream);
}
void nvte_dqgeglu_cast_transpose(const NVTETensor input,
const NVTETensor gated_act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgeglu_cast_transpose);
using namespace transformer_engine;
dgated_act_cast_transpose<fp32, Empty, dqgelu<fp32, fp32>, qgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(gated_act_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
stream);
}
......@@ -53,13 +53,13 @@ __device__ inline OType dqgelu(const IType val, const Empty& e) {
}
template <typename OType, typename IType>
__device__ inline OType swish(const IType val, const Empty& e) {
__device__ inline OType silu(const IType val, const Empty& e) {
const float cval = val;
return cval * sigmoid<float, float>(cval, e);
}
template <typename OType, typename IType>
__device__ inline OType dswish(const IType val, const Empty& e) {
__device__ inline OType dsilu(const IType val, const Empty& e) {
const float cval = val;
return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
}
......@@ -74,6 +74,15 @@ __device__ inline OType drelu(IType value, const Empty &) {
return value > 0.f ? 1.f : 0.f;
}
template <typename OType, typename IType>
__device__ inline OType srelu(IType value, const Empty &) {
return value > 0 ? value * value : 0.f;
}
template <typename OType, typename IType>
__device__ inline OType dsrelu(IType value, const Empty &) {
return fmaxf(2.f * value, 0.f);
}
} // namespace transformer_engine
......
......@@ -27,7 +27,7 @@ from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine_jax import NVTE_Activation_Enum
from transformer_engine_jax import NVTE_Activation_Type
from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
......@@ -126,10 +126,16 @@ def _check_valid_batch_dims(bdims):
ActivationEnum = {
('gelu',): NVTE_Activation_Enum.GELU,
('gelu', 'linear'): NVTE_Activation_Enum.GEGLU,
('silu',): NVTE_Activation_Enum.SILU,
('silu', 'linear'): NVTE_Activation_Enum.SWIGLU
('gelu',): NVTE_Activation_Type.GELU,
('gelu', 'linear'): NVTE_Activation_Type.GEGLU,
('silu',): NVTE_Activation_Type.SILU,
('silu', 'linear'): NVTE_Activation_Type.SWIGLU,
('relu',): NVTE_Activation_Type.RELU,
('relu', 'linear'): NVTE_Activation_Type.REGLU,
('quick_gelu',): NVTE_Activation_Type.QGELU,
('quick_gelu', 'linear'): NVTE_Activation_Type.QGEGLU,
('squared_relu',): NVTE_Activation_Type.SRELU,
('squared_relu', 'linear'): NVTE_Activation_Type.SREGLU,
}
......@@ -2655,7 +2661,7 @@ class ActLuPrimitive(BasePrimitive):
"""
act_lu partitioning
"""
del result_infos
del result_infos, act_enum
x_spec = get_padded_spec(arg_infos[0])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
......@@ -2721,6 +2727,7 @@ class DActLuPrimitive(BasePrimitive):
ir_in_shape = ir_in_type.shape
gi_type = ir.RankedTensorType(x.type)
gi_shape = gi_type.shape
# assert ir_in_shape == gi_shape
for axis in range(len(ir_in_shape) - 1):
assert ir_in_shape[axis] == gi_shape[axis]
......@@ -2783,7 +2790,7 @@ class DActLuPrimitive(BasePrimitive):
"""
dact_lu partition
"""
del result_infos
del result_infos, act_enum
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = dx_sharding
......
......@@ -9,6 +9,7 @@
#include <cublasLt.h>
#include "common/include/transformer_engine/fused_attn.h"
#include "common/include/transformer_engine/activation.h"
#include "common/include/transformer_engine/transformer_engine.h"
#include "jax/csrc/modules.h"
#include "jax/csrc/utils.h"
......@@ -101,11 +102,17 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD);
pybind11::enum_<NVTE_Activation_Enum>(m, "NVTE_Activation_Enum", pybind11::module_local())
.value("GELU", NVTE_Activation_Enum::GELU)
.value("GEGLU", NVTE_Activation_Enum::GEGLU)
.value("SILU", NVTE_Activation_Enum::SILU)
.value("SWIGLU", NVTE_Activation_Enum::SWIGLU);
pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
.value("GELU", NVTE_Activation_Type::GELU)
.value("GEGLU", NVTE_Activation_Type::GEGLU)
.value("SILU", NVTE_Activation_Type::SILU)
.value("SWIGLU", NVTE_Activation_Type::SWIGLU)
.value("RELU", NVTE_Activation_Type::RELU)
.value("REGLU", NVTE_Activation_Type::REGLU)
.value("QGELU", NVTE_Activation_Type::QGELU)
.value("QGEGLU", NVTE_Activation_Type::QGEGLU)
.value("SRELU", NVTE_Activation_Type::SRELU)
.value("SREGLU", NVTE_Activation_Type::SREGLU);
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
......
......@@ -37,12 +37,18 @@ std::vector<size_t> MakeShapeVector(NVTEShape shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}
size_t get_activation_len(NVTE_Activation_Enum act_enum) {
switch (act_enum) {
case NVTE_Activation_Enum::GELU: return 1;
case NVTE_Activation_Enum::GEGLU: return 2;
case NVTE_Activation_Enum::SILU: return 1;
case NVTE_Activation_Enum::SWIGLU: return 2;
size_t get_activation_len(NVTE_Activation_Type activation_enum) {
switch (activation_enum) {
case NVTE_Activation_Type::GELU: return 1;
case NVTE_Activation_Type::GEGLU: return 2;
case NVTE_Activation_Type::SILU: return 1;
case NVTE_Activation_Type::SWIGLU: return 2;
case NVTE_Activation_Type::RELU: return 1;
case NVTE_Activation_Type::REGLU: return 2;
case NVTE_Activation_Type::QGELU: return 1;
case NVTE_Activation_Type::QGEGLU: return 2;
case NVTE_Activation_Type::SRELU: return 1;
case NVTE_Activation_Type::SREGLU: return 2;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
......@@ -188,7 +194,7 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output,
NVTE_Activation_Enum act_enum) {
NVTE_Activation_Type act_enum) {
auto act_len = get_activation_len(act_enum);
auto input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n};
......@@ -198,18 +204,36 @@ void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype,
static_cast<DType>(out_dtype), amax,
scale, scale_inverse);
switch (act_enum) {
case NVTE_Activation_Enum::GELU:
case NVTE_Activation_Type::GELU:
nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::GEGLU:
case NVTE_Activation_Type::GEGLU:
nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::SILU:
nvte_swish(input_tensor.data(), output_tensor.data(), stream);
case NVTE_Activation_Type::SILU:
nvte_silu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::SWIGLU:
case NVTE_Activation_Type::SWIGLU:
nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::RELU:
nvte_relu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::REGLU:
nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::QGELU:
nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::QGEGLU:
nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SRELU:
nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SREGLU:
nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
......@@ -223,7 +247,7 @@ void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaqu
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);;
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream,
nullptr, nullptr, output, act_enum);
......@@ -246,7 +270,7 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);;
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream,
scale_inv, amax_out, output, act_enum);
......@@ -260,7 +284,7 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);;
auto act_len = get_activation_len(act_enum);
auto input_shape = std::vector<size_t>{m, n};
......@@ -272,22 +296,46 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq
auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
switch (act_enum) {
case NVTE_Activation_Enum::GELU:
case NVTE_Activation_Type::GELU:
nvte_dgelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::GEGLU:
case NVTE_Activation_Type::GEGLU:
nvte_dgeglu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::SILU:
nvte_dswish(input_tensor.data(), act_input_tensor.data(),
case NVTE_Activation_Type::SILU:
nvte_dsilu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::SWIGLU:
case NVTE_Activation_Type::SWIGLU:
nvte_dswiglu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
case NVTE_Activation_Type::RELU:
nvte_drelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
case NVTE_Activation_Type::REGLU:
nvte_dreglu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
case NVTE_Activation_Type::QGELU:
nvte_dqgelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
case NVTE_Activation_Type::QGEGLU:
nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SRELU:
nvte_dsrelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SREGLU:
nvte_dsreglu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
......@@ -341,7 +389,7 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);;
auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
......@@ -359,18 +407,33 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
switch (act_enum) {
case NVTE_Activation_Enum::GELU:
case NVTE_Activation_Type::GELU:
nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
break;
case NVTE_Activation_Enum::SILU:
nvte_cast_transpose_dbias_dswish(input_tensor.data(), act_input_tensor.data(),
case NVTE_Activation_Type::SILU:
nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
break;
case NVTE_Activation_Type::RELU:
nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
break;
case NVTE_Activation_Type::QGELU:
nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
break;
case NVTE_Activation_Type::SRELU:
nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
break;
default:
throw std::runtime_error("Activation Type is not Implemented in DActLuDBiasCastTranspose");
NVTE_ERROR("Unsupported ActivationEnum");
break;
}
}
......@@ -395,7 +458,7 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);;
auto input_shape = desc.shape.to_vector();
auto act_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
......@@ -409,16 +472,31 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
switch (act_enum) {
case NVTE_Activation_Enum::GEGLU:
case NVTE_Activation_Type::GEGLU:
nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
stream);
break;
case NVTE_Activation_Enum::SWIGLU:
case NVTE_Activation_Type::SWIGLU:
nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
stream);
break;
case NVTE_Activation_Type::REGLU:
nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
stream);
break;
case NVTE_Activation_Type::QGEGLU:
nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
stream);
break;
case NVTE_Activation_Type::SREGLU:
nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
......
......@@ -18,6 +18,7 @@
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/activation.h>
#include "common/util/logging.h"
namespace transformer_engine {
......@@ -25,6 +26,8 @@ namespace jax {
constexpr int kMaxNumDim = 8;
size_t get_activation_len(NVTE_Activation_Type activation_enum);
struct Shape {
int num_dim;
size_t dims[kMaxNumDim];
......
......@@ -944,9 +944,15 @@ class LayerNormMLP(TransformerEngineBase):
) and not self.return_layernorm_output and self.enable_layernorm
gated_act_pool = [('gelu', 'linear'),
('silu', 'linear')]
('silu', 'linear'),
('relu', 'linear'),
('quick_gelu', 'linear'),
('squared_relu', 'linear')]
act_pool = [('gelu',),
('silu',)]
('silu',),
('relu',),
('quick_gelu',),
('squared_relu',)]
normalize_acts = []
for act in self.activations:
if not isinstance(act, str):
......
......@@ -15,7 +15,7 @@ from .cpp_extensions import act_lu, act_lu_fp8, dact_lu
from .cpp_extensions import dact_lu_dbias_cast_transpose, dgated_act_lu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
......@@ -56,72 +56,6 @@ def _activation_lu_bwd_rule(activation_type, ctx, g):
_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                      fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
                      activation_type: Sequence[Union[str, Callable]]):
    """
    Activation Unit
    """
    # Placeholder cotangent buffers: the custom VJP also produces a transposed
    # dgrad and a dbias, which callers of this forward entry point never consume.
    trans_shape = [x.shape[i] for i in (1, 2, 0)]
    dummy_dx_trans = jnp.empty(trans_shape, dtype=x.dtype)
    dummy_dbias = jnp.empty(x.shape[-1], dtype=x.dtype)
    return _activation_lu_fp8(x, dummy_dx_trans, dummy_dbias, amax, scale, scale_inv,
                              fwd_dtype, bwd_dtype, activation_type)
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
def _activation_lu_fp8(x: jnp.ndarray, dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray,
                       amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                       fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
                       activation_type: Sequence[Union[str, Callable]]):
    """Differentiable FP8 activation unit; primal evaluation defers to the forward rule."""
    # Outside of autodiff tracing, the primal simply evaluates the forward rule.
    return _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax, scale,
                                       scale_inv, fwd_dtype, bwd_dtype, activation_type)
def _activation_lu_fp8_fwd_rule(x,
                                dx_trans_no_use,  # pylint: disable=unused-argument
                                dbias_no_use,  # pylint: disable=unused-argument
                                amax,
                                scale, scale_inv,
                                fwd_dtype, bwd_dtype,  # pylint: disable=unused-argument
                                activation_type):
    """Forward rule: FP8 activation, then dequantize back to the input dtype."""
    quantized_out, _ = act_lu_fp8(x, amax, scale, scale_inv, fwd_dtype, activation_type)
    # Bring the FP8 result back to x.dtype so downstream consumers see the input precision.
    out = dequantize(quantized_out, x.dtype, scale_inv)
    residuals = (x, amax, scale, scale_inv)
    return out, residuals
def _activation_lu_fp8_bwd_rule(
        fwd_dtype,  # pylint: disable=unused-argument
        bwd_dtype,
        activation_type,
        ctx,
        g):
    """Backward rule: fused FP8 dgrad (+ transpose, and + dbias when non-gated)."""
    x, amax, scale, scale_inv = ctx
    is_gated = len(activation_type) > 1
    if is_gated:
        # Gated path: dgrad fused with cast+transpose; no bias term, so dbias is a placeholder.
        dact, dact_trans, updated_amax = \
            dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype, -1,
                                         activation_type)
        dbias = jnp.empty(x.shape[-1], x.dtype)
    else:
        # Non-gated path: the fused kernel additionally reduces dbias.
        dact, dact_trans, dbias, updated_amax = \
            dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype,
                                         -1, -2, activation_type)
    dact = dequantize(dact, x.dtype, scale_inv)
    dact_trans = dequantize(dact_trans, x.dtype, scale_inv)
    return (dact, dact_trans, dbias, updated_amax, scale, scale_inv)


_activation_lu_fp8.defvjp(_activation_lu_fp8_fwd_rule, _activation_lu_fp8_bwd_rule)
def fused_layernorm_fp8_mlp(x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
......@@ -231,6 +165,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
amax = FP8Helper.update_amax_history(amax)
gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment