Unverified Commit c473f0e6 authored by Phuong Nguyen, committed by GitHub

[JAX] Adding Gated/Non-gated ReLU, Quick GeLU, Squared ReLU (#826)



* renamed gelu to act

* added relu, srelu, qgelu

* fixed initialization for layernorm_fp8_mlp tests

* moved activation_fp8 prim into testunit file

* Moved NVTE_Activation_Enum to common/.../activation.h

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 07291027
@@ -19,7 +19,9 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti
 from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
 from transformer_engine.jax.fp8 import is_fp8_available
 from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
-from transformer_engine.jax.mlp import activation_lu, activation_lu_fp8, fused_layernorm_fp8_mlp
+from transformer_engine.jax.mlp import activation_lu, fused_layernorm_fp8_mlp
+from transformer_engine.jax.cpp_extensions import act_lu_fp8, dact_lu_dbias_cast_transpose
+from transformer_engine.jax.cpp_extensions import dgated_act_lu_cast_transpose

 GEMM_CASES = [
@@ -34,10 +36,16 @@ LN_CASES = [(512, 1024)]
 DTYPES = [jnp.bfloat16, jnp.float32]

 is_fp8_supported, reason = is_fp8_available()

 def _convert_to_activation_function(fn_or_string):
     """Convert a string to an activation function."""
     if fn_or_string == 'linear':
         return lambda x: x
+    if fn_or_string == 'quick_gelu':
+        return lambda x: nn.gelu(x, approximate=True)
+    if fn_or_string == 'squared_relu':
+        return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
     if isinstance(fn_or_string, str):
         return getattr(nn, fn_or_string)
     if callable(fn_or_string):
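The two new string cases above define the reference activations used by the tests: `quick_gelu` reuses flax's tanh-approximated GeLU, and `squared_relu` is ReLU multiplied by itself. A standalone sketch of those references in plain JAX (equivalent to the test code, outside the test harness):

```python
# Reference activations matching the test hunk above; nn.relu(x) * nn.relu(x)
# is the same value the functools.reduce(operator.mul, ...) form computes.
import jax.numpy as jnp
from flax import linen as nn

quick_gelu = lambda x: nn.gelu(x, approximate=True)   # tanh-approximated GeLU
squared_relu = lambda x: nn.relu(x) * nn.relu(x)      # relu(x) ** 2

x = jnp.array([-1.0, 0.5, 2.0])
print(quick_gelu(x), squared_relu(x))
```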
@@ -171,14 +179,20 @@ class TestFP8Dot:
         assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize('m,n,k', [(128, 256, 512),
+    @pytest.mark.parametrize('m,n,k', [(256, 128, 512),
                                        (16384, 1024, 2816),
                                        (16384, 2816, 1024),
                                        (16384, 1024, 1024)])
     @pytest.mark.parametrize('activation_type', [('gelu', ),
                                                  ('gelu', 'linear'),
                                                  ('silu', ),
-                                                 ('silu', 'linear')])
+                                                 ('silu', 'linear'),
+                                                 ('relu',),
+                                                 ('relu', 'linear'),
+                                                 ('quick_gelu',),
+                                                 ('quick_gelu', 'linear'),
+                                                 ('squared_relu',),
+                                                 ('squared_relu', 'linear')])
     @pytest.mark.parametrize('use_bias', [True, False])
     def test_grad_fused_layernorm_fp8_mlp(self, m, n, k,
                                           activation_type: Sequence[Union[str, Callable]],
                                           use_bias: bool):
@@ -187,8 +201,8 @@ class TestFP8Dot:
         subkeys = jax.random.split(key, 6)

         a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
-        k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16)
-        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
+        k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) / jnp.sqrt(k)
+        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
         s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
         if use_bias:
             b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
@@ -345,7 +359,13 @@ class TestActivationLu:
     @pytest.mark.parametrize('activation_type', [('gelu',),
                                                  ('gelu', 'linear'),
                                                  ('silu',),
-                                                 ('silu', 'linear')])
+                                                 ('silu', 'linear'),
+                                                 ('relu',),
+                                                 ('relu', 'linear'),
+                                                 ('quick_gelu',),
+                                                 ('quick_gelu', 'linear'),
+                                                 ('squared_relu',),
+                                                 ('squared_relu', 'linear')])
     def test_activation_lu(self, random_inputs, activation_type):
         x = random_inputs
         x = jnp.repeat(x, len(activation_type), axis=1)
@@ -363,37 +383,74 @@ class TestActivationLu:

 class TestActivationLuFP8(TestActivationLu):

-    def primitive_func(self, inputs, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv):
-        return jnp.mean(
-            activation_lu_fp8(inputs,
-                              amax, scale, scale_inv,
-                              jnp.float8_e4m3fn, jnp.float8_e5m2,
-                              activation_type=self.activation_type))
+    def prim_func(self, x):
+        amax = self.amax
+        scale = self.scale
+        scale_inv = self.scale_inv
+        activation_type = self.activation_type
+
+        @jax.custom_vjp
+        def _prim_func(x, _x_t, _dbias, _amax):
+            output = _prim_func_fwd(x, _x_t, _dbias, _amax)
+            return output
+
+        def _prim_func_fwd(x, _x_t, _dbias, _amax):
+            activation_lu_out, _ = act_lu_fp8(x, amax, scale, scale_inv,
+                                              FP8Helper.FWD_DTYPE, activation_type)
+            activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
+            ctx = (x)
+            return activation_lu_out, ctx
+
+        def _prim_func_bwd(ctx, g):
+            x = ctx
+            if len(self.activation_type) > 1:    # gated, no bias
+                dactivation_lu, dactivation_lu_trans, amax_out = \
+                    dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv,
                                                  FP8Helper.BWD_DTYPE, -1, activation_type)
+                dbias = jnp.empty(x.shape[-1], x.dtype)
+            else:    # not gated, with bias
+                dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
+                    dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE,
                                                  -1, -2, self.activation_type)
+            dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
+            dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
+            ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
+            return ctx
+
+        _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)
+
+        dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype)
+        dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
+        amax_no_use = jnp.zeros(1, jnp.float32)
+        value_n_grad_primitive_func = value_and_grad(
+            lambda a, b, c, d: jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3))
+
+        return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)
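The rewritten `prim_func` builds a local `jax.custom_vjp` wrapper so that `value_and_grad` exercises the FP8 backward kernels instead of autodiff. For readers unfamiliar with the pattern, a minimal sketch of the same fwd/bwd plumbing, with a stand-in ReLU in place of the FP8 primitives:

```python
# Minimal jax.custom_vjp sketch mirroring the structure of _prim_func above.
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.maximum(x, 0.0)

def f_fwd(x):
    return f(x), x                       # residuals saved for the backward pass

def f_bwd(res, g):
    x = res
    return (jnp.where(x > 0, g, 0.0),)   # one cotangent per primal argument

f.defvjp(f_fwd, f_bwd)

print(jax.grad(lambda x: f(x).sum())(jnp.array([-1.0, 2.0])))  # [0., 1.]
```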
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
     @pytest.mark.parametrize('activation_type', [('gelu',),
                                                  ('gelu', 'linear'),
                                                  ('silu',),
-                                                 ('silu', 'linear')])
+                                                 ('silu', 'linear'),
+                                                 ('relu',),
+                                                 ('relu', 'linear'),
+                                                 ('quick_gelu',),
+                                                 ('quick_gelu', 'linear'),
+                                                 ('squared_relu',),
+                                                 ('squared_relu', 'linear')])
     def test_activation_lu(self, random_inputs, activation_type):
         self.amax = jnp.zeros(1, jnp.float32)
         self.scale = jnp.ones(1, jnp.float32)
         self.scale_inv = jnp.ones(1, jnp.float32)
         self.activation_type = activation_type
+        self.transpose_indices = (1, 2, 0)

         x = random_inputs
         x = jnp.repeat(x, len(activation_type), axis=1)

-        value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0, 1, 2, 3, 4, 5,)))
-        transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1)
-        dx_trans_no_use = jnp.zeros([x.shape[i] for i in transpose_indices], dtype=x.dtype)
-        dbias_no_use = jnp.zeros(x.shape[-1], dtype=x.dtype)
-
-        prim_out, (prim_grad, prim_grad_trans, dbias, amax, _, _) = \
-            value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use,
-                                        self.amax, self.scale, self.scale_inv)
+        prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
         ref_out, (ref_grad,) = self.ref_func(x, activation_type)

         assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
@@ -402,7 +459,7 @@ class TestActivationLuFP8(TestActivationLu):
         assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))
         assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
         assert_allclose(prim_grad_trans,
-                        jnp.transpose(ref_grad, transpose_indices),
+                        jnp.transpose(ref_grad, self.transpose_indices),
                         dtype=FP8Helper.BWD_DTYPE)
......
@@ -73,3 +73,26 @@ void nvte_dqgelu(const NVTETensor grad,
                                            reinterpret_cast<Tensor*>(output),
                                            stream);
 }
+
+void nvte_qgeglu(const NVTETensor input,
+                 NVTETensor output,
+                 cudaStream_t stream) {
+  NVTE_API_CALL(nvte_qgeglu);
+  using namespace transformer_engine;
+  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
+                                               reinterpret_cast<Tensor*>(output),
+                                               stream);
+}
+
+void nvte_dqgeglu(const NVTETensor grad,
+                  const NVTETensor input,
+                  NVTETensor output,
+                  cudaStream_t stream) {
+  NVTE_API_CALL(nvte_dqgeglu);
+  using namespace transformer_engine;
+  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(
+      *reinterpret_cast<const Tensor*>(grad),
+      *reinterpret_cast<const Tensor*>(input),
+      reinterpret_cast<Tensor*>(output),
+      stream);
+}
@@ -52,3 +52,48 @@ void nvte_dreglu(const NVTETensor grad,
                                           reinterpret_cast<Tensor*>(output),
                                           stream);
 }
+
+void nvte_srelu(const NVTETensor input,
+                NVTETensor output,
+                cudaStream_t stream) {
+  NVTE_API_CALL(nvte_srelu);
+  using namespace transformer_engine;
+  act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
+                                         reinterpret_cast<Tensor*>(output),
+                                         stream);
+}
+
+void nvte_dsrelu(const NVTETensor grad,
+                 const NVTETensor input,
+                 NVTETensor output,
+                 cudaStream_t stream) {
+  NVTE_API_CALL(nvte_dsrelu);
+  using namespace transformer_engine;
+  dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
+                                           *reinterpret_cast<const Tensor*>(input),
+                                           reinterpret_cast<Tensor*>(output),
+                                           stream);
+}
+
+void nvte_sreglu(const NVTETensor input,
+                 NVTETensor output,
+                 cudaStream_t stream) {
+  NVTE_API_CALL(nvte_sreglu);
+  using namespace transformer_engine;
+  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
+                                               reinterpret_cast<Tensor*>(output),
+                                               stream);
+}
+
+void nvte_dsreglu(const NVTETensor grad,
+                  const NVTETensor input,
+                  NVTETensor output,
+                  cudaStream_t stream) {
+  NVTE_API_CALL(nvte_dsreglu);
+  using namespace transformer_engine;
+  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(
+      *reinterpret_cast<const Tensor*>(grad),
+      *reinterpret_cast<const Tensor*>(input),
+      reinterpret_cast<Tensor*>(output),
+      stream);
+}
@@ -8,23 +8,23 @@

 #include "../util/math.h"

-void nvte_swish(const NVTETensor input,
-                NVTETensor output,
-                cudaStream_t stream) {
-  NVTE_API_CALL(nvte_swish);
+void nvte_silu(const NVTETensor input,
+               NVTETensor output,
+               cudaStream_t stream) {
+  NVTE_API_CALL(nvte_silu);
   using namespace transformer_engine;
-  act_fn<fp32, Empty, swish<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
-                                         reinterpret_cast<Tensor*>(output),
-                                         stream);
+  act_fn<fp32, Empty, silu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
+                                        reinterpret_cast<Tensor*>(output),
+                                        stream);
 }

-void nvte_dswish(const NVTETensor grad,
-                 const NVTETensor input,
-                 NVTETensor output,
-                 cudaStream_t stream) {
-  NVTE_API_CALL(nvte_dswish);
+void nvte_dsilu(const NVTETensor grad,
+                const NVTETensor input,
+                NVTETensor output,
+                cudaStream_t stream) {
+  NVTE_API_CALL(nvte_dsilu);
   using namespace transformer_engine;
-  dact_fn<fp32, Empty, dswish<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
-                                           *reinterpret_cast<const Tensor*>(input),
-                                           reinterpret_cast<Tensor*>(output),
-                                           stream);
+  dact_fn<fp32, Empty, dsilu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
+                                          *reinterpret_cast<const Tensor*>(input),
+                                          reinterpret_cast<Tensor*>(output),
+                                          stream);
@@ -35,7 +35,7 @@ void nvte_swiglu(const NVTETensor input,
                  cudaStream_t stream) {
   NVTE_API_CALL(nvte_swiglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, swish<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
+  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                                reinterpret_cast<Tensor*>(output),
                                                stream);
 }
@@ -46,7 +46,7 @@ void nvte_dswiglu(const NVTETensor grad,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dswiglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, swish<fp32, fp32>, dswish<fp32, fp32>>(
+  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(
       *reinterpret_cast<const Tensor*>(grad),
       *reinterpret_cast<const Tensor*>(input),
       reinterpret_cast<Tensor*>(output),
......
@@ -17,20 +17,52 @@
 extern "C" {
 #endif

-/*! \brief Compute GELU activation of the input.
+/* Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */
+
+/*! \brief Compute activation of the input.
  *
- * \param[in]     input     Input tensor for GELU activation.
+ * \param[in]     input     Input tensor for activation.
  * \param[in,out] output    Output tensor.
  * \param[in]     stream    CUDA stream used for the operation.
  */
+enum class NVTE_Activation_Type {
+  GELU,
+  GEGLU,
+  SILU,
+  SWIGLU,
+  RELU,
+  REGLU,
+  QGELU,
+  QGEGLU,
+  SRELU,
+  SREGLU,
+};
+
 void nvte_gelu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream);

-/*! \brief Compute GELU activation gradient.
+void nvte_silu(const NVTETensor input,
+               NVTETensor output,
+               cudaStream_t stream);
+
+void nvte_relu(const NVTETensor input,
+               NVTETensor output,
+               cudaStream_t stream);
+
+void nvte_qgelu(const NVTETensor input,
+                NVTETensor output,
+                cudaStream_t stream);
+
+void nvte_srelu(const NVTETensor input,
+                NVTETensor output,
+                cudaStream_t stream);
+
+/*! \brief Compute activation gradient.
  *
  * \param[in]     grad      Incoming gradient.
- * \param[in]     input     Input tensor for GELU activation.
+ * \param[in]     input     Input tensor for activation.
  * \param[in,out] output    Output tensor.
  * \param[in]     stream    CUDA stream used for the operation.
  */
@@ -39,135 +71,81 @@ void nvte_dgelu(const NVTETensor grad,
                 NVTETensor output,
                 cudaStream_t stream);

-/*! \brief Compute GeGLU of the input.
- *
- * \param[in]     input     Input tensor of shape [N, H * 2].
- * \param[in,out] output    Output tensor of shape [N, H].
- *                          It computes GELU(input[N, :H]) x input[N, H:]
- * \param[in]     stream    CUDA stream used for the operation.
- */
-void nvte_geglu(const NVTETensor input,
-                NVTETensor output,
-                cudaStream_t stream);
-
-/*! \brief Compute GeGLU gradient.
- * \param[in]     grad      Incoming gradient of shape [N, H].
- * \param[in]     input     Forward input tensor of shape [N, H * 2].
- * \param[in,out] output    Outgoing gradient of shape [N, H * 2].
- * \param[in]     stream    CUDA stream used for the operation.
- */
-void nvte_dgeglu(const NVTETensor grad,
+void nvte_dsilu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream);

-/*! \brief Compute SiLU activation of the input.
- *
- * \param[in]     input     Input tensor for GELU activation.
- * \param[in,out] output    Output tensor.
- * \param[in]     stream    CUDA stream used for the operation.
- */
-void nvte_swish(const NVTETensor input,
-                NVTETensor output,
-                cudaStream_t stream);
-
-/*! \brief Compute Swish activation gradient.
- *
- * \param[in]     grad      Incoming gradient.
- * \param[in]     input     Input tensor for Swish activation.
- * \param[in,out] output    Output tensor.
- * \param[in]     stream    CUDA stream used for the operation.
- */
-void nvte_dswish(const NVTETensor grad,
+void nvte_drelu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream);

-/*! \brief Compute SwiGLU activation of the input.
+void nvte_dqgelu(const NVTETensor grad,
+                 const NVTETensor input,
+                 NVTETensor output,
+                 cudaStream_t stream);
+
+void nvte_dsrelu(const NVTETensor grad,
+                 const NVTETensor input,
+                 NVTETensor output,
+                 cudaStream_t stream);
+
+/*! \brief Compute gated activation of the input.
  *
  * \param[in]     input     Input tensor of shape [N, H * 2].
  * \param[in,out] output    Output tensor of shape [N, H].
- *                          It computes Swish(input[N, :H]) x input[N, H:]
+ *                          It computes Act(input[N, :H]) x input[N, H:]
  * \param[in]     stream    CUDA stream used for the operation.
  */
+void nvte_geglu(const NVTETensor input,
+                NVTETensor output,
+                cudaStream_t stream);
+
 void nvte_swiglu(const NVTETensor input,
                  NVTETensor output,
                  cudaStream_t stream);

-/*! \brief Compute SwiGLU gradient.
- * \param[in]     grad      Incoming gradient of shape [N, H].
- * \param[in]     input     Forward input tensor of shape [N, H * 2].
- * \param[in,out] output    Outgoing gradient of shape [N, H * 2].
- * \param[in]     stream    CUDA stream used for the operation.
- */
-void nvte_dswiglu(const NVTETensor grad,
-                  const NVTETensor input,
-                  NVTETensor output,
-                  cudaStream_t stream);
-
-/*! \brief Compute RELU activation of the input.
- *
- * \param[in]     input     Input tensor for RELU activation.
- * \param[in,out] output    Output tensor.
- * \param[in]     stream    CUDA stream used for the operation.
- */
-void nvte_relu(const NVTETensor input,
-               NVTETensor output,
-               cudaStream_t stream);
-
-/*! \brief Compute RELU activation gradient.
- *
- * \param[in]     grad      Incoming gradient.
- * \param[in]     input     Input tensor for RELU activation.
- * \param[in,out] output    Output tensor.
- * \param[in]     stream    CUDA stream used for the operation.
- */
-void nvte_drelu(const NVTETensor grad,
-                const NVTETensor input,
-                NVTETensor output,
-                cudaStream_t stream);
-
-/*! \brief Compute ReGLU activation of the input.
- *
- * \param[in]     input     Input tensor of shape [N, H * 2].
- * \param[in,out] output    Output tensor of shape [N, H].
- *                          It computes ReLU(input[N, :H]) x input[N, H:]
- * \param[in]     stream    CUDA stream used for the operation.
- */
 void nvte_reglu(const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream);

-/*! \brief Compute ReGLU gradient.
+void nvte_qgeglu(const NVTETensor input,
+                 NVTETensor output,
+                 cudaStream_t stream);
+
+void nvte_sreglu(const NVTETensor input,
+                 NVTETensor output,
+                 cudaStream_t stream);
+
+/*! \brief Compute gated activation gradient.
  * \param[in]     grad      Incoming gradient of shape [N, H].
  * \param[in]     input     Forward input tensor of shape [N, H * 2].
  * \param[in,out] output    Outgoing gradient of shape [N, H * 2].
  * \param[in]     stream    CUDA stream used for the operation.
  */
+void nvte_dgeglu(const NVTETensor grad,
+                 const NVTETensor input,
+                 NVTETensor output,
+                 cudaStream_t stream);
+
+void nvte_dswiglu(const NVTETensor grad,
+                  const NVTETensor input,
+                  NVTETensor output,
+                  cudaStream_t stream);
+
 void nvte_dreglu(const NVTETensor grad,
                  const NVTETensor input,
                  NVTETensor output,
                  cudaStream_t stream);

-/*! \brief Compute QuickGELU activation of the input.
- *
- * \param[in]     input     Input tensor for QuickGELU activation.
- * \param[in,out] output    Output tensor. Approximates GELU as input x sigmoid(1.702 x input).
- * \param[in]     stream    CUDA stream used for the operation.
- */
-void nvte_qgelu(const NVTETensor input,
-                NVTETensor output,
-                cudaStream_t stream);
-
-/*! \brief Compute QuickGELU activation gradient.
- *
- * \param[in]     grad      Incoming gradient.
- * \param[in]     input     Input tensor for QuickGELU activation.
- * \param[in,out] output    Output tensor.
- * \param[in]     stream    CUDA stream used for the operation.
- */
-void nvte_dqgelu(const NVTETensor grad,
+void nvte_dqgeglu(const NVTETensor grad,
+                  const NVTETensor input,
+                  NVTETensor output,
+                  cudaStream_t stream);
+
+void nvte_dsreglu(const NVTETensor grad,
                  const NVTETensor input,
                  NVTETensor output,
                  cudaStream_t stream);
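The consolidated doc comment above defines every gated variant the same way: the input packs two halves along the last dimension, and the kernel returns `Act(input[N, :H]) x input[N, H:]`. A reference sketch of that contract in JAX (`gated_act_reference` is illustrative, not a library function):

```python
# Reference for the gated prototypes above: input packs [a | b] along the
# last axis and the kernel returns act(a) * b.
import jax.numpy as jnp

def gated_act_reference(x, act):
    h = x.shape[-1] // 2
    return act(x[..., :h]) * x[..., h:]

srelu = lambda v: jnp.where(v > 0, v * v, 0.0)   # squared ReLU
x = jnp.arange(8.0).reshape(2, 4)
print(gated_act_reference(x, srelu))             # what nvte_sreglu computes
```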
......
@@ -90,35 +90,6 @@ void nvte_fp8_transpose_dbias(const NVTETensor input,
                               NVTETensor workspace,
                               cudaStream_t stream);

-/*! \brief Compute backward of GELU operation on the input, then cast and transpose. Additionally,
- *         reduce the result of the GELU backward along the first dimension.
- *
- * This function produces 3 results:
- *  - `cast_output` is equal to `cast(dGELU(input))`
- *  - `transposed_output` is equal to `transpose(cast(dGELU(input)))`
- *  - `dbias` is equal to `reduce(dGELU(input), axis=0)`
- *
- * Calling this function with workspace being an empty tensor will not perform the operation,
- * but instead set the shape and type of the workspace tensor to the required values.
- *
- * \param[in]     input             Input tensor of shape [N, H].
- * \param[in]     gelu_input        Tensor used as input to the forward of GELU operation.
- *                                  Shape [N, H].
- * \param[in,out] cast_output       Result of the cast. Shape: [N, H].
- * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
- * \param[out]    dbias             Result of the reduction of the dGELU(input) along the
- *                                  first dimension. Shape: [H].
- * \param[out]    workspace         Workspace tensor.
- * \param[in]     stream            CUDA stream used for the operation.
- */
-void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
-                                     const NVTETensor gelu_input,
-                                     NVTETensor cast_output,
-                                     NVTETensor transposed_output,
-                                     NVTETensor dbias,
-                                     NVTETensor workspace,
-                                     cudaStream_t stream);
-
 /*! \brief Cast and transpose multiple tensors.
  *
  * This function casts each input tensor and produces 2 results:
@@ -140,38 +111,19 @@ void nvte_multi_cast_transpose(size_t num_tensors,
                                NVTETensor* transposed_output_list,
                                cudaStream_t stream);

-/*! \brief Compute dgeglu of the input, additionally does cast and transpose the dgeglu output.
- *
- * This function produces 2 results:
- *  - `cast_output` is the result of the cast
- *  - `transposed_output` is the transposed result of the cast.
- *
- * \param[in]     input             Input tensor of shape [N, H].
- * \param[in]     geglu_input       Tensor used as input to the forward of GeGLU operation.
- *                                  Shape [N, H * 2].
- * \param[in,out] cast_output       Result of the cast. Shape: [N, H * 2].
- * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N].
- * \param[in]     stream            CUDA stream used for the operation.
- */
-void nvte_dgeglu_cast_transpose(const NVTETensor input,
-                                const NVTETensor geglu_input,
-                                NVTETensor cast_output,
-                                NVTETensor transposed_output,
-                                cudaStream_t stream);
-
-/*! \brief Compute backward of SiLU operation on the input, then cast and transpose. Additionally,
+/*! \brief Compute backward of ActLU operation on the input, then cast and transpose. Additionally,
  *         reduce the result of the SiLU backward along the first dimension.
  *
  * This function produces 3 results:
- *  - `cast_output` is equal to `cast(dSiLU(input))`
- *  - `transposed_output` is equal to `transpose(cast(dSiLU(input)))`
- *  - `dbias` is equal to `reduce(dSiLU(input), axis=0)`
+ *  - `cast_output` is equal to `cast(dact(input))`
+ *  - `transposed_output` is equal to `transpose(cast(dact(input)))`
+ *  - `dbias` is equal to `reduce(dact(input), axis=0)`
  *
  * Calling this function with workspace being an empty tensor will not perform the operation,
  * but instead set the shape and type of the workspace tensor to the required values.
  *
  * \param[in]     input             Input tensor of shape [N, H].
- * \param[in]     swish_input       Tensor used as input to the forward of SiLU operation.
+ * \param[in]     act_input         Tensor used as input to the forward of SiLU operation.
  *                                  Shape [N, H].
  * \param[in,out] cast_output       Result of the cast. Shape: [N, H].
  * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
@@ -179,33 +131,97 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input,
  *                                  first dimension. Shape: [H].
  * \param[out]    workspace         Workspace tensor.
  * \param[in]     stream            CUDA stream used for the operation.
+ *
+ * Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU
  */
-void nvte_cast_transpose_dbias_dswish(const NVTETensor input,
-                                      const NVTETensor swish_input,
+void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
+                                     const NVTETensor act_input,
                                      NVTETensor cast_output,
                                      NVTETensor transposed_output,
                                      NVTETensor dbias,
                                      NVTETensor workspace,
                                      cudaStream_t stream);
+
+void nvte_cast_transpose_dbias_dsilu(const NVTETensor input,
+                                     const NVTETensor act_input,
+                                     NVTETensor cast_output,
+                                     NVTETensor transposed_output,
+                                     NVTETensor dbias,
+                                     NVTETensor workspace,
+                                     cudaStream_t stream);
+
+void nvte_cast_transpose_dbias_drelu(const NVTETensor input,
+                                     const NVTETensor act_input,
+                                     NVTETensor cast_output,
+                                     NVTETensor transposed_output,
+                                     NVTETensor dbias,
+                                     NVTETensor workspace,
+                                     cudaStream_t stream);
+
+void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input,
+                                      const NVTETensor act_input,
+                                      NVTETensor cast_output,
+                                      NVTETensor transposed_output,
+                                      NVTETensor dbias,
+                                      NVTETensor workspace,
+                                      cudaStream_t stream);
+
+void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input,
+                                      const NVTETensor act_input,
+                                      NVTETensor cast_output,
+                                      NVTETensor transposed_output,
+                                      NVTETensor dbias,
+                                      NVTETensor workspace,
+                                      cudaStream_t stream);

-/*! \brief Compute dswiglu of the input, additionally does cast and transpose the dswiglu output.
+/*! \brief Compute dgeglu of the input, additionally does cast and transpose the dgeglu output.
  *
  * This function produces 2 results:
  *  - `cast_output` is the result of the cast
  *  - `transposed_output` is the transposed result of the cast.
  *
  * \param[in]     input             Input tensor of shape [N, H].
- * \param[in]     swiglu_input      Tensor used as input to the forward of SwiGLU operation.
+ * \param[in]     gated_act_input   Tensor used as input to the forward of GeGLU operation.
  *                                  Shape [N, H * 2].
  * \param[in,out] cast_output       Result of the cast. Shape: [N, H * 2].
  * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N].
  * \param[in]     stream            CUDA stream used for the operation.
+ *
+ * Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU
  */
+void nvte_dgeglu_cast_transpose(const NVTETensor input,
+                                const NVTETensor act_input,
+                                NVTETensor cast_output,
+                                NVTETensor transposed_output,
+                                cudaStream_t stream);
+
 void nvte_dswiglu_cast_transpose(const NVTETensor input,
-                                 const NVTETensor swiglu_input,
+                                 const NVTETensor act_input,
                                  NVTETensor cast_output,
                                  NVTETensor transposed_output,
                                  cudaStream_t stream);
+
+void nvte_dreglu_cast_transpose(const NVTETensor input,
+                                const NVTETensor act_input,
+                                NVTETensor cast_output,
+                                NVTETensor transposed_output,
+                                cudaStream_t stream);
+
+void nvte_dqgeglu_cast_transpose(const NVTETensor input,
+                                 const NVTETensor act_input,
+                                 NVTETensor cast_output,
+                                 NVTETensor transposed_output,
+                                 cudaStream_t stream);
+
+void nvte_dsreglu_cast_transpose(const NVTETensor input,
+                                 const NVTETensor act_input,
+                                 NVTETensor cast_output,
+                                 NVTETensor transposed_output,
+                                 cudaStream_t stream);

 #ifdef __cplusplus
 }  // extern "C"
 #endif
......
@@ -53,13 +53,13 @@ __device__ inline OType dqgelu(const IType val, const Empty& e) {
 }

 template <typename OType, typename IType>
-__device__ inline OType swish(const IType val, const Empty& e) {
+__device__ inline OType silu(const IType val, const Empty& e) {
   const float cval = val;
   return cval * sigmoid<float, float>(cval, e);
 }

 template <typename OType, typename IType>
-__device__ inline OType dswish(const IType val, const Empty& e) {
+__device__ inline OType dsilu(const IType val, const Empty& e) {
   const float cval = val;
   return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
 }
@@ -74,6 +74,15 @@ __device__ inline OType drelu(IType value, const Empty &) {
   return value > 0.f ? 1.f : 0.f;
 }

+template <typename OType, typename IType>
+__device__ inline OType srelu(IType value, const Empty &) {
+  return value > 0 ? value * value : 0.f;
+}
+
+template <typename OType, typename IType>
+__device__ inline OType dsrelu(IType value, const Empty &) {
+  return fmaxf(2.f * value, 0.f);
+}
+
 }  // namespace transformer_engine
......
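The `dsrelu` kernel above uses the fact that the derivative of x^2 for x > 0 (and 0 elsewhere) is exactly `max(2x, 0)`. A quick JAX check of that identity (a sketch, not part of the PR):

```python
# Verify dsrelu against JAX autodiff on the squared-ReLU definition above.
import jax
import jax.numpy as jnp

srelu = lambda x: jnp.where(x > 0, x * x, 0.0)
dsrelu = lambda x: jnp.maximum(2.0 * x, 0.0)   # mirrors the CUDA dsrelu

x = jnp.linspace(-2.0, 2.0, 9)
auto_grad = jax.vmap(jax.grad(srelu))(x)
print(jnp.allclose(auto_grad, dsrelu(x)))      # True
```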
@@ -27,7 +27,7 @@ from transformer_engine_jax import NVTE_Bias_Type
 from transformer_engine_jax import NVTE_Mask_Type
 from transformer_engine_jax import NVTE_QKV_Layout
 from transformer_engine_jax import NVTE_Fused_Attn_Backend
-from transformer_engine_jax import NVTE_Activation_Enum
+from transformer_engine_jax import NVTE_Activation_Type
 from .sharding import all_reduce_max_along_all_axes_except_PP
 from .sharding import all_reduce_sum_along_dp_fsdp
@@ -126,10 +126,16 @@ def _check_valid_batch_dims(bdims):
 ActivationEnum = {
-    ('gelu',): NVTE_Activation_Enum.GELU,
-    ('gelu', 'linear'): NVTE_Activation_Enum.GEGLU,
-    ('silu',): NVTE_Activation_Enum.SILU,
-    ('silu', 'linear'): NVTE_Activation_Enum.SWIGLU
+    ('gelu',): NVTE_Activation_Type.GELU,
+    ('gelu', 'linear'): NVTE_Activation_Type.GEGLU,
+    ('silu',): NVTE_Activation_Type.SILU,
+    ('silu', 'linear'): NVTE_Activation_Type.SWIGLU,
+    ('relu',): NVTE_Activation_Type.RELU,
+    ('relu', 'linear'): NVTE_Activation_Type.REGLU,
+    ('quick_gelu',): NVTE_Activation_Type.QGELU,
+    ('quick_gelu', 'linear'): NVTE_Activation_Type.QGEGLU,
+    ('squared_relu',): NVTE_Activation_Type.SRELU,
+    ('squared_relu', 'linear'): NVTE_Activation_Type.SREGLU,
 }
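With the expanded table, the JAX-side activation tuples map one-to-one onto the C++ enum: pairing an activation with `'linear'` selects the gated (`*GLU`) variant. A usage sketch, assuming the module layout in this diff:

```python
# Activation tuples index straight into the table above; gated -> *GLU.
from transformer_engine.jax.cpp_extensions import ActivationEnum
from transformer_engine_jax import NVTE_Activation_Type

assert ActivationEnum[('squared_relu',)] == NVTE_Activation_Type.SRELU
assert ActivationEnum[('quick_gelu', 'linear')] == NVTE_Activation_Type.QGEGLU
```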
@@ -2655,7 +2661,7 @@ class ActLuPrimitive(BasePrimitive):
         """
         act_lu partitioning
         """
-        del result_infos
+        del result_infos, act_enum
         x_spec = get_padded_spec(arg_infos[0])
         arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
         out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
@@ -2721,6 +2727,7 @@ class DActLuPrimitive(BasePrimitive):
         ir_in_shape = ir_in_type.shape
         gi_type = ir.RankedTensorType(x.type)
         gi_shape = gi_type.shape
+        # assert ir_in_shape == gi_shape
         for axis in range(len(ir_in_shape) - 1):
             assert ir_in_shape[axis] == gi_shape[axis]
@@ -2783,7 +2790,7 @@ class DActLuPrimitive(BasePrimitive):
         """
         dact_lu partition
         """
-        del result_infos
+        del result_infos, act_enum
         dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
         arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
         out_shardings = dx_sharding
......
@@ -9,6 +9,7 @@
 #include <cublasLt.h>

 #include "common/include/transformer_engine/fused_attn.h"
+#include "common/include/transformer_engine/activation.h"
 #include "common/include/transformer_engine/transformer_engine.h"
 #include "jax/csrc/modules.h"
 #include "jax/csrc/utils.h"
@@ -101,11 +102,17 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
         .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
         .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD);

-    pybind11::enum_<NVTE_Activation_Enum>(m, "NVTE_Activation_Enum", pybind11::module_local())
-        .value("GELU", NVTE_Activation_Enum::GELU)
-        .value("GEGLU", NVTE_Activation_Enum::GEGLU)
-        .value("SILU", NVTE_Activation_Enum::SILU)
-        .value("SWIGLU", NVTE_Activation_Enum::SWIGLU);
+    pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
+        .value("GELU", NVTE_Activation_Type::GELU)
+        .value("GEGLU", NVTE_Activation_Type::GEGLU)
+        .value("SILU", NVTE_Activation_Type::SILU)
+        .value("SWIGLU", NVTE_Activation_Type::SWIGLU)
+        .value("RELU", NVTE_Activation_Type::RELU)
+        .value("REGLU", NVTE_Activation_Type::REGLU)
+        .value("QGELU", NVTE_Activation_Type::QGELU)
+        .value("QGEGLU", NVTE_Activation_Type::QGEGLU)
+        .value("SRELU", NVTE_Activation_Type::SRELU)
+        .value("SREGLU", NVTE_Activation_Type::SREGLU);

     pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
         .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
......
@@ -37,12 +37,18 @@ std::vector<size_t> MakeShapeVector(NVTEShape shape) {
   return std::vector<size_t>(shape.data, shape.data + shape.ndim);
 }

-size_t get_activation_len(NVTE_Activation_Enum act_enum) {
-  switch (act_enum) {
-    case NVTE_Activation_Enum::GELU: return 1;
-    case NVTE_Activation_Enum::GEGLU: return 2;
-    case NVTE_Activation_Enum::SILU: return 1;
-    case NVTE_Activation_Enum::SWIGLU: return 2;
+size_t get_activation_len(NVTE_Activation_Type activation_enum) {
+  switch (activation_enum) {
+    case NVTE_Activation_Type::GELU: return 1;
+    case NVTE_Activation_Type::GEGLU: return 2;
+    case NVTE_Activation_Type::SILU: return 1;
+    case NVTE_Activation_Type::SWIGLU: return 2;
+    case NVTE_Activation_Type::RELU: return 1;
+    case NVTE_Activation_Type::REGLU: return 2;
+    case NVTE_Activation_Type::QGELU: return 1;
+    case NVTE_Activation_Type::QGEGLU: return 2;
+    case NVTE_Activation_Type::SRELU: return 1;
+    case NVTE_Activation_Type::SREGLU: return 2;
     default:
       NVTE_ERROR("Unsupported ActivationEnum");
       break;
@@ -188,7 +194,7 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size

 void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
                cudaStream_t stream, float *scale_inverse, float *amax, void *output,
-               NVTE_Activation_Enum act_enum) {
+               NVTE_Activation_Type act_enum) {
   auto act_len = get_activation_len(act_enum);
   auto input_shape = std::vector<size_t>{m, n * act_len};
   auto output_shape = std::vector<size_t>{m, n};
@@ -198,18 +204,36 @@ void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype,
                                    static_cast<DType>(out_dtype), amax,
                                    scale, scale_inverse);
   switch (act_enum) {
-    case NVTE_Activation_Enum::GELU:
+    case NVTE_Activation_Type::GELU:
       nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
       break;
-    case NVTE_Activation_Enum::GEGLU:
+    case NVTE_Activation_Type::GEGLU:
       nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
       break;
-    case NVTE_Activation_Enum::SILU:
-      nvte_swish(input_tensor.data(), output_tensor.data(), stream);
+    case NVTE_Activation_Type::SILU:
+      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
       break;
-    case NVTE_Activation_Enum::SWIGLU:
+    case NVTE_Activation_Type::SWIGLU:
       nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
       break;
+    case NVTE_Activation_Type::RELU:
+      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
+      break;
+    case NVTE_Activation_Type::REGLU:
+      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
+      break;
+    case NVTE_Activation_Type::QGELU:
+      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
+      break;
+    case NVTE_Activation_Type::QGEGLU:
+      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
+      break;
+    case NVTE_Activation_Type::SRELU:
+      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
+      break;
+    case NVTE_Activation_Type::SREGLU:
+      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
+      break;
     default:
       NVTE_ERROR("Unsupported ActivationEnum");
       break;
@@ -223,7 +247,7 @@ void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaqu
   const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
   auto m = desc.shape.dims[0];
   auto n = desc.shape.dims[1];
-  auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);
+  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);

   ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream,
             nullptr, nullptr, output, act_enum);
@@ -246,7 +270,7 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
   }
   auto m = desc.shape.dims[0];
   auto n = desc.shape.dims[1];
-  auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);
+  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);

   ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream,
             scale_inv, amax_out, output, act_enum);
@@ -260,7 +284,7 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq
   const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
   auto m = desc.shape.dims[0];
   auto n = desc.shape.dims[1];
-  auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);
+  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
   auto act_len = get_activation_len(act_enum);

   auto input_shape = std::vector<size_t>{m, n};
@@ -272,22 +296,46 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq
   auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);

   switch (act_enum) {
-    case NVTE_Activation_Enum::GELU:
+    case NVTE_Activation_Type::GELU:
       nvte_dgelu(input_tensor.data(), act_input_tensor.data(),
                  output_tensor.data(), stream);
       break;
-    case NVTE_Activation_Enum::GEGLU:
+    case NVTE_Activation_Type::GEGLU:
       nvte_dgeglu(input_tensor.data(), act_input_tensor.data(),
                   output_tensor.data(), stream);
       break;
-    case NVTE_Activation_Enum::SILU:
-      nvte_dswish(input_tensor.data(), act_input_tensor.data(),
-                  output_tensor.data(), stream);
+    case NVTE_Activation_Type::SILU:
+      nvte_dsilu(input_tensor.data(), act_input_tensor.data(),
+                 output_tensor.data(), stream);
       break;
-    case NVTE_Activation_Enum::SWIGLU:
+    case NVTE_Activation_Type::SWIGLU:
       nvte_dswiglu(input_tensor.data(), act_input_tensor.data(),
                    output_tensor.data(), stream);
       break;
+    case NVTE_Activation_Type::RELU:
+      nvte_drelu(input_tensor.data(), act_input_tensor.data(),
+                 output_tensor.data(), stream);
+      break;
+    case NVTE_Activation_Type::REGLU:
+      nvte_dreglu(input_tensor.data(), act_input_tensor.data(),
+                  output_tensor.data(), stream);
+      break;
+    case NVTE_Activation_Type::QGELU:
+      nvte_dqgelu(input_tensor.data(), act_input_tensor.data(),
+                  output_tensor.data(), stream);
+      break;
+    case NVTE_Activation_Type::QGEGLU:
+      nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(),
+                   output_tensor.data(), stream);
+      break;
+    case NVTE_Activation_Type::SRELU:
+      nvte_dsrelu(input_tensor.data(), act_input_tensor.data(),
+                  output_tensor.data(), stream);
+      break;
+    case NVTE_Activation_Type::SREGLU:
+      nvte_dsreglu(input_tensor.data(), act_input_tensor.data(),
+                   output_tensor.data(), stream);
+      break;
     default:
       NVTE_ERROR("Unsupported ActivationEnum");
       break;
@@ -341,7 +389,7 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
   }
   auto m = desc.shape.dims[0];
   auto n = desc.shape.dims[1];
-  auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);
+  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
   auto input_shape = std::vector<size_t>{m, n};
   auto act_input_shape = std::vector<size_t>{m, n};
   auto output_shape = std::vector<size_t>{m, n};
@@ -359,18 +407,33 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
   auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);

   switch (act_enum) {
-    case NVTE_Activation_Enum::GELU:
+    case NVTE_Activation_Type::GELU:
       nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                       output_tensor.data(), output_trans_tensor.data(),
                                       dbias_tensor.data(), workspace.data(), stream);
       break;
-    case NVTE_Activation_Enum::SILU:
-      nvte_cast_transpose_dbias_dswish(input_tensor.data(), act_input_tensor.data(),
-                                       output_tensor.data(), output_trans_tensor.data(),
-                                       dbias_tensor.data(), workspace.data(), stream);
+    case NVTE_Activation_Type::SILU:
+      nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
+                                      output_tensor.data(), output_trans_tensor.data(),
+                                      dbias_tensor.data(), workspace.data(), stream);
       break;
+    case NVTE_Activation_Type::RELU:
+      nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
+                                      output_tensor.data(), output_trans_tensor.data(),
+                                      dbias_tensor.data(), workspace.data(), stream);
+      break;
+    case NVTE_Activation_Type::QGELU:
+      nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
+                                       output_tensor.data(), output_trans_tensor.data(),
+                                       dbias_tensor.data(), workspace.data(), stream);
+      break;
+    case NVTE_Activation_Type::SRELU:
+      nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
+                                       output_tensor.data(), output_trans_tensor.data(),
+                                       dbias_tensor.data(), workspace.data(), stream);
+      break;
     default:
-      throw std::runtime_error("Activation Type is not Implemented in DActLuDBiasCastTranspose");
+      NVTE_ERROR("Unsupported ActivationEnum");
       break;
   }
 }
@@ -395,7 +458,7 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
   }
   auto m = desc.shape.dims[0];
   auto n = desc.shape.dims[1];
-  auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);
+  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
   auto input_shape = desc.shape.to_vector();
   auto act_input_shape = std::vector<size_t>{m, n * 2};
   auto output_shape = std::vector<size_t>{m, n * 2};
@@ -409,16 +472,31 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
       TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);

   switch (act_enum) {
-    case NVTE_Activation_Enum::GEGLU:
+    case NVTE_Activation_Type::GEGLU:
       nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(),
                                  stream);
       break;
-    case NVTE_Activation_Enum::SWIGLU:
+    case NVTE_Activation_Type::SWIGLU:
       nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), output_trans_tensor.data(),
                                   stream);
       break;
+    case NVTE_Activation_Type::REGLU:
+      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
+                                 output_tensor.data(), output_trans_tensor.data(),
+                                 stream);
+      break;
+    case NVTE_Activation_Type::QGEGLU:
+      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
+                                  output_tensor.data(), output_trans_tensor.data(),
+                                  stream);
+      break;
+    case NVTE_Activation_Type::SREGLU:
+      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
+                                  output_tensor.data(), output_trans_tensor.data(),
+                                  stream);
+      break;
     default:
       NVTE_ERROR("Unsupported ActivationEnum");
       break;
......
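For the `DGatedActLuCastTranspose` dispatch above, the math every `nvte_d*glu_cast_transpose` kernel implements (before the FP8 cast and transpose) is the VJP of the gated forward: for input `x = [a | b]`, the cotangents are `da = g * b * act'(a)` and `db = g * act(a)`. A sketch recovering this with `jax.vjp`, using SiLU as the example activation:

```python
# Reference semantics of the gated backward kernels via jax.vjp.
import jax
import jax.numpy as jnp

def gated_act(x, act):
    h = x.shape[-1] // 2
    return act(x[:, :h]) * x[:, h:]

silu = lambda v: v * jax.nn.sigmoid(v)
x = jnp.ones((2, 4))
_, vjp = jax.vjp(lambda t: gated_act(t, silu), x)
(dx,) = vjp(jnp.ones((2, 2)))
print(dx.shape)   # (2, 4): packed [da | db], matching the [N, H * 2] output
```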
@@ -18,6 +18,7 @@

 #include <transformer_engine/fused_attn.h>
 #include <transformer_engine/transformer_engine.h>
+#include <transformer_engine/activation.h>

 #include "common/util/logging.h"

 namespace transformer_engine {
@@ -25,6 +26,8 @@ namespace jax {

 constexpr int kMaxNumDim = 8;

+size_t get_activation_len(NVTE_Activation_Type activation_enum);
+
 struct Shape {
   int num_dim;
   size_t dims[kMaxNumDim];
......
@@ -944,9 +944,15 @@ class LayerNormMLP(TransformerEngineBase):
                 ) and not self.return_layernorm_output and self.enable_layernorm

         gated_act_pool = [('gelu', 'linear'),
-                          ('silu', 'linear')]
+                          ('silu', 'linear'),
+                          ('relu', 'linear'),
+                          ('quick_gelu', 'linear'),
+                          ('squared_relu', 'linear')]
         act_pool = [('gelu',),
-                    ('silu',)]
+                    ('silu',),
+                    ('relu',),
+                    ('quick_gelu',),
+                    ('squared_relu',)]
         normalize_acts = []
         for act in self.activations:
             if not isinstance(act, str):
......
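With the pools extended, `LayerNormMLP` accepts the new activation tuples directly. A usage sketch; the import path and constructor arguments other than `activations` are illustrative of the JAX flax API at the time of this PR, not verified against it:

```python
# Hypothetical usage of the expanded activation pools in LayerNormMLP.
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormMLP

mlp = LayerNormMLP(intermediate_dim=128, activations=('squared_relu', 'linear'))
x = jnp.ones((4, 16, 64))
params = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
y = mlp.apply(params, x, deterministic=True)
```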
@@ -15,7 +15,7 @@ from .cpp_extensions import act_lu, act_lu_fp8, dact_lu
 from .cpp_extensions import dact_lu_dbias_cast_transpose, dgated_act_lu_cast_transpose
 from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
 from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
-from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize
+from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
 from .layernorm import canonicalize_layernorm_type
 from .fp8 import FP8Helper, FP8MetaPackage
 from .sharding import with_sharding_constraint_by_logical_axes
@@ -56,72 +56,6 @@ def _activation_lu_bwd_rule(activation_type, ctx, g):

 _activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
-def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
-                      fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
-                      activation_type: Sequence[Union[str, Callable]]):
-    """
-    Activation Unit
-    """
-    transpose_indices = (1, 2, 0)
-    dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype)
-    dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
-
-    output = _activation_lu_fp8(x, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv, fwd_dtype,
-                                bwd_dtype, activation_type)
-    return output
-
-@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
-def _activation_lu_fp8(x: jnp.ndarray, dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray,
-                       amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
-                       fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
-                       activation_type: Sequence[Union[str, Callable]]):
-    output = _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv,
-                                         fwd_dtype, bwd_dtype, activation_type)
-    return output
-
-def _activation_lu_fp8_fwd_rule(x,
-                                dx_trans_no_use,    # pylint: disable=unused-argument
-                                dbias_no_use,    # pylint: disable=unused-argument
-                                amax,
-                                scale, scale_inv,
-                                fwd_dtype, bwd_dtype,    # pylint: disable=unused-argument
-                                activation_type):
-    activation_lu_out, _ = act_lu_fp8(x, amax, scale, scale_inv, fwd_dtype,
-                                      activation_type)
-    activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
-    ctx = (x, amax, scale, scale_inv)
-    return activation_lu_out, ctx
-
-def _activation_lu_fp8_bwd_rule(
-        fwd_dtype,    # pylint: disable=unused-argument
-        bwd_dtype,
-        activation_type,
-        ctx,
-        g):
-    x, amax, scale, scale_inv = ctx
-    if len(activation_type) > 1:    # gated, no bias
-        dactivation_lu, dactivation_lu_trans, amax_out = \
-            dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype, -1,
-                                         activation_type)
-        dbias = jnp.empty(x.shape[-1], x.dtype)
-    else:    # not gated, with bias
-        dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
-            dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype,
-                                         -1, -2, activation_type)
-    dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
-    dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
-    ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv)
-    return ctx
-
-_activation_lu_fp8.defvjp(_activation_lu_fp8_fwd_rule, _activation_lu_fp8_bwd_rule)
 def fused_layernorm_fp8_mlp(x: jnp.ndarray,
                             gamma: jnp.ndarray,
                             beta: jnp.ndarray,
@@ -231,6 +165,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
         FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
     fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
     scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
+
     amax = FP8Helper.update_amax_history(amax)

     gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)