Unverified Commit f85553ea authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] SwiGLU Implementation (#773)



* Implemented swiglu and silu
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* Renamed nvte-*silu to nvte-*swish + generalized GetDBiasDact functions
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent b9954408
......@@ -15,15 +15,12 @@ from jax import jit, value_and_grad
from flax import linen as nn
from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import dgelu, dgelu_dbias_cast_transpose
from transformer_engine.jax.cpp_extensions import gelu, gelu_fp8
from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu
from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import fused_layernorm_fp8_mlp
from transformer_engine.jax.mlp import activation_lu, activation_lu_fp8, fused_layernorm_fp8_mlp
GEMM_CASES = [
(256, 256, 512),
......@@ -37,6 +34,16 @@ LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
if fn_or_string == 'linear':
return lambda x: x
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
if callable(fn_or_string):
return fn_or_string
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
......@@ -174,22 +181,21 @@ class TestFP8Dot:
assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 512, 128), (16384, 1024, 2816), (16384, 2816, 1024),
@pytest.mark.parametrize('m,n,k', [(128, 256, 512),
(16384, 1024, 2816),
(16384, 2816, 1024),
(16384, 1024, 1024)])
@pytest.mark.parametrize('activation_type', [('gelu', ),
('gelu', 'linear')])
('gelu', 'linear'),
('silu', ),
('silu', 'linear')])
@pytest.mark.parametrize('use_bias', [True, False])
def test_grad_fused_layernorm_fp8_mlp(self, m, n, k,
activation_type: Sequence[Union[str, Callable]],
use_bias: bool):
activation_type: Sequence[Union[str, Callable]], use_bias: bool):
""" N/a """
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6)
activation_dict = {
('gelu', ): jax.nn.gelu
}
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
......@@ -218,15 +224,6 @@ class TestFP8Dot:
fused_layernorm_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm",
activation_type = activation_type, use_bias = use_bias))
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
if fn_or_string == 'linear':
return lambda x: x
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
if callable(fn_or_string):
return fn_or_string
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
......@@ -249,15 +246,12 @@ class TestFP8Dot:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
if 'linear' in activation_type:
x = jnp.split(linear_1_out, len(activation_type), axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
else:
x = activation_dict[activation_type](linear_1_out)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
......@@ -331,7 +325,6 @@ class TestFP8Dot:
dtype=jnp.bfloat16)
@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
key = jax.random.PRNGKey(0)
......@@ -340,190 +333,86 @@ def random_inputs_fixture(shape):
return out
class TestGeLu:
class TestActivationLu:
def ref_func(self, inputs):
func = jit(value_and_grad(lambda x: jnp.mean(jax.nn.gelu(x))))
return func(inputs)
def prim_func(self, inputs):
@jax.custom_vjp
def primitive(x):
out, _ = primitive_fwd(x)
return out
def primitive_fwd(x):
out = gelu(x)
ctx = x
return out, ctx
def ref_func(self, x, activation_type):
def ref_act_lu(inputs):
x = jnp.split(inputs, len(activation_type), axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
return jnp.mean(x)
def primitive_bwd(ctx, g):
x = ctx
out = dgelu(g, x)
return (out,)
ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
return ref_act_func(x)
primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x: jnp.mean(primitive(x)))
return func(inputs)
def primitive_func(self, inputs):
return jnp.mean(activation_lu(inputs, activation_type = self.activation_type))
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
def test_gelu(self, random_inputs):
@pytest.mark.parametrize('activation_type', [('gelu',),
('gelu', 'linear'),
('silu',),
('silu', 'linear')])
def test_activation_lu(self, random_inputs, activation_type):
x = random_inputs
prim_out, prim_grad = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)
self.activation_type = activation_type
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
class TestGeLuFP8(TestGeLu):
def prim_func(self, inputs):
amax = self.amax
scale = self.scale
scale_inv = self.scale_inv
no_use = jnp.zeros(1, jnp.float32)
value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)))
@jax.custom_vjp
def primitive(x, y, z, w):
out = primitive_fwd(x)
return out
prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
ref_out, (ref_grad,) = self.ref_func(x, activation_type)
""" prim_grad, = prim_grad """
""" ref_grad, = ref_grad """
def primitive_fwd(x, y, z, w):
out, _ = gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn)
out = dequantize(out, x.dtype, scale_inv)
ctx = x
return out, ctx
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
def primitive_bwd(ctx, g):
x = ctx
dgelu, dgelu_trans, dbias, amax_out = dgelu_dbias_cast_transpose(
g, x, amax, scale, scale_inv, jnp.float8_e5m2, -1)
dgelu = dequantize(dgelu, x.dtype, scale_inv)
dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv)
return dgelu, dgelu_trans, dbias, amax_out
primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x, y, z, w: jnp.mean(primitive(x, y, z, w)), (0, 1, 2, 3))
class TestActivationLuFP8(TestActivationLu):
return func(inputs, jnp.transpose(inputs, (2, 0, 1)),
jnp.zeros(inputs.shape[-1], dtype=inputs.dtype), no_use)
def primitive_func(self, inputs, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv):
return jnp.mean(
activation_lu_fp8(inputs,
amax, scale, scale_inv,
jnp.float8_e4m3fn, jnp.float8_e5m2,
activation_type = self.activation_type))
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
def test_gelu(self, random_inputs):
@pytest.mark.parametrize('activation_type', [('gelu',),
('gelu', 'linear'),
('silu',),
('silu', 'linear')])
def test_activation_lu(self, random_inputs, activation_type):
self.amax = jnp.zeros(1, jnp.float32)
self.scale = jnp.ones(1, jnp.float32)
self.scale_inv = jnp.ones(1, jnp.float32)
self.activation_type = activation_type
x = random_inputs
prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)
assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(prim_grad_trans,
jnp.transpose(ref_grad, (2, 0, 1)),
dtype=FP8Helper.BWD_DTYPE)
class TestGatedGeLu:
def ref_func(self, inputs):
def jax_gated_gelu(x):
x = jnp.split(x, 2, axis=-2)
acts = [jax.nn.gelu(x[0]), x[1]]
x = functools.reduce(operator.mul, acts)
x = jnp.asarray(jnp.squeeze(x, -2), jnp.bfloat16)
return x
func = jit(value_and_grad(lambda x: jnp.mean(jax_gated_gelu(x))))
return func(inputs)
def prim_func(self, inputs):
@jax.custom_vjp
def primitive(x):
out, _ = primitive_fwd(x)
return out
def primitive_fwd(x):
out = gated_gelu(x)
ctx = x
return out, ctx
def primitive_bwd(ctx, g):
x = ctx
out = dgated_gelu(g, x)
return (out,)
primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x: jnp.mean(primitive(x)))
return func(inputs)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
def test_gated_gelu(self, random_inputs):
x = random_inputs
prim_out, prim_grad = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
class TestGatedGeLuFP8(TestGatedGeLu):
def prim_func(self, inputs):
amax = self.amax
scale = self.scale
scale_inv = self.scale_inv
no_use = jnp.zeros(1, jnp.float32)
@jax.custom_vjp
def primitive(x, y, z):
out = primitive_fwd(x)
return out
def primitive_fwd(x, y, z):
out, _ = gated_gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn)
out = dequantize(out, x.dtype, scale_inv)
ctx = x
return out, ctx
value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0, 1, 2, 3, 4, 5,)))
def primitive_bwd(ctx, g):
x = ctx
dgelu, dgelu_trans, amax_out = dgated_gelu_cast_transpose(g, x, amax, scale, scale_inv,
jnp.float8_e5m2, -1)
dgelu = dequantize(dgelu, x.dtype, scale_inv)
dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv)
return dgelu, dgelu_trans, amax_out
transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1)
dx_trans_no_use = jnp.zeros([x.shape[i] for i in transpose_indices], dtype=x.dtype)
dbias_no_use = jnp.zeros(x.shape[-1], dtype=x.dtype)
primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x, y, z: jnp.mean(primitive(x, y, z)), (0, 1, 2))
return func(inputs, jnp.transpose(inputs, (1, 2, 0)), no_use)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
def test_gated_gelu(self, random_inputs):
self.amax = jnp.zeros(1, jnp.float32)
self.scale = jnp.ones(1, jnp.float32)
self.scale_inv = jnp.ones(1, jnp.float32)
x = random_inputs
prim_out, (prim_grad, prim_grad_trans, amax) = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)
prim_out, (prim_grad, prim_grad_trans, dbias, amax, _, _) = \
value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use,
self.amax, self.scale, self.scale_inv)
ref_out, (ref_grad,) = self.ref_func(x, activation_type)
assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
if 'linear' not in activation_type:
assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(prim_grad_trans,
jnp.transpose(ref_grad, (1, 2, 0)),
jnp.transpose(ref_grad, transpose_indices),
dtype=FP8Helper.BWD_DTYPE)
......
......@@ -158,6 +158,33 @@ ATTRS = [{
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('gelu',)),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_DROPOUT_RATE: 0.8,
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('silu',)),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
......
......@@ -543,11 +543,25 @@ class LayerNormMLPAttr:
ACTIVATION: ('gelu', 'linear')
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
USE_BIAS: False,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear')
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('silu', 'linear')
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: False,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('silu', 'linear')
}]
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/activation.h>
#include <cuda_runtime.h>
#include "../util/vectorized_pointwise.h"
#include "../common.h"
namespace transformer_engine {
template <typename ComputeType, typename Param,
ComputeType (*OP)(ComputeType, const Param&)>
void act_fn(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "act_lu_input");
CheckOutputTensor(*output, "act_lu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
const size_t tot_elts = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, Param, OP>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const ComputeType*>(output->scale.dptr),
reinterpret_cast<ComputeType*>(output->amax.dptr),
tot_elts,
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
template <typename ComputeType, typename Param,
ComputeType (*OP)(ComputeType, const Param&)>
void dact_fn(const Tensor &grad,
const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "dact_lu_input");
CheckInputTensor(grad, "dact_lu_input_grad");
CheckOutputTensor(*output, "dact_lu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
NVTE_CHECK(input.data.dtype == grad.data.dtype,
"Input and incoming gradient types must match.");
const size_t tot_elts = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryGradKernelLauncher<nvec, Param, OP>(
reinterpret_cast<const IType*>(grad.data.dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const ComputeType*>(output->scale.dptr),
reinterpret_cast<ComputeType*>(output->amax.dptr),
tot_elts,
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
template <typename ComputeType, typename Param,
ComputeType (*OP)(ComputeType, const Param&)>
void gated_act_fn(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
"Input shape[0] must be equal to output shape[0].");
NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
"Input shape[1] must be 2x larger than output shape[1].");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, ComputeType, Param, OP>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const ComputeType*>(output->scale.dptr),
reinterpret_cast<ComputeType*>(output->amax.dptr),
output->data.shape[0],
output->data.shape[1],
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
template <typename ComputeType, typename Param,
ComputeType (*OP1)(ComputeType, const Param&),
ComputeType (*OP2)(ComputeType, const Param&)>
void dgated_act_fn(const Tensor &grad,
const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input");
CheckOutputTensor(*output, "dgated_act_output");
NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
"Output shape[0] must be equal to grad shape[0].");
NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
"Output shape[1] must be 2x larger than grad shape[1].");
NVTE_CHECK(input.data.shape == output->data.shape,
"Input and output shapes must match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, ComputeType, Param, OP1, OP2>(
reinterpret_cast<const IType*>(grad.data.dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
grad.data.shape[0],
grad.data.shape[1],
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
......@@ -3,189 +3,16 @@
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/activation.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <iostream>
#include "../utils.cuh"
#include "../common.h"
#include <cstdlib>
#include <../util/vectorized_pointwise.h>
#include "./activation_template.h"
#include "../util/math.h"
namespace transformer_engine {
void gelu(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "gelu_input");
CheckOutputTensor(*output, "gelu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
const size_t tot_elts = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, Empty, gelu<fp32, fp32> >(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
tot_elts,
Empty(),
stream);
); // NOLINT(*)
); // NOLINT(*)
}
void dgelu(const Tensor &grad,
const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "dgelu_input");
CheckInputTensor(grad, "dgelu_input_grad");
CheckOutputTensor(*output, "dgelu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
NVTE_CHECK(input.data.dtype == grad.data.dtype,
"Input and incoming gradient types must match.");
const size_t tot_elts = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryGradKernelLauncher<nvec, Empty, dgelu<fp32, fp32>>(
reinterpret_cast<const IType*>(grad.data.dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
tot_elts,
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
void geglu(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "geglu_input");
CheckOutputTensor(*output, "geglu_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
"Input shape[0] must be equal to output shape[0].");
NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
"Input shape[1] must be 2x larger than output shape[1].");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, fp32, Empty, gelu<fp32, fp32>>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
output->data.shape[0],
output->data.shape[1],
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
void dgeglu(const Tensor &grad,
const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(grad, "dgeglu_grad");
CheckInputTensor(input, "dgeglu_input");
CheckOutputTensor(*output, "dgeglu_output");
NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
"Output shape[0] must be equal to grad shape[0].");
NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
"Output shape[1] must be 2x larger than grad shape[1].");
NVTE_CHECK(input.data.shape == output->data.shape,
"Input and output shapes must match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
reinterpret_cast<const IType*>(grad.data.dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
grad.data.shape[0],
grad.data.shape[1],
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
void qgelu(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "qgelu_input");
CheckOutputTensor(*output, "qgelu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
const size_t tot_elts = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, Empty, qgelu<fp32, fp32> >(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
tot_elts,
Empty(),
stream);
); // NOLINT(*)
); // NOLINT(*)
}
void dqgelu(const Tensor &grad,
const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "dqgelu_input");
CheckInputTensor(grad, "dqgelu_input_grad");
CheckOutputTensor(*output, "dqgelu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
NVTE_CHECK(input.data.dtype == grad.data.dtype,
"Input and incoming gradient types must match.");
const size_t tot_elts = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryGradKernelLauncher<nvec, Empty, dqgelu<fp32, fp32>>(
reinterpret_cast<const IType*>(grad.data.dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
tot_elts,
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
void nvte_gelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_gelu);
using namespace transformer_engine;
gelu(*reinterpret_cast<const Tensor*>(input),
act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
......@@ -196,7 +23,7 @@ void nvte_dgelu(const NVTETensor grad,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgelu);
using namespace transformer_engine;
dgelu(*reinterpret_cast<const Tensor*>(grad),
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
......@@ -207,7 +34,7 @@ void nvte_geglu(const NVTETensor input,
cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
geglu(*reinterpret_cast<const Tensor*>(input),
gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
......@@ -218,7 +45,8 @@ void nvte_dgeglu(const NVTETensor grad,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu);
using namespace transformer_engine;
dgeglu(*reinterpret_cast<const Tensor*>(grad),
dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
......@@ -229,7 +57,7 @@ void nvte_qgelu(const NVTETensor input,
cudaStream_t stream) {
NVTE_API_CALL(nvte_qgelu);
using namespace transformer_engine;
qgelu(*reinterpret_cast<const Tensor*>(input),
act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
......@@ -240,7 +68,7 @@ void nvte_dqgelu(const NVTETensor grad,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgelu);
using namespace transformer_engine;
dqgelu(*reinterpret_cast<const Tensor*>(grad),
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
......
......@@ -4,134 +4,16 @@
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/activation.h>
#include <cuda_runtime.h>
#include "../util/vectorized_pointwise.h"
#include "./activation_template.h"
#include "../util/math.h"
#include "../common.h"
namespace transformer_engine {
void relu(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "relu_input");
CheckOutputTensor(*output, "relu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
const size_t tot_elts = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, Empty, relu<fp32, fp32>>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
tot_elts,
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
void drelu(const Tensor &grad,
const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "drelu_input");
CheckInputTensor(grad, "drelu_input_grad");
CheckOutputTensor(*output, "drelu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
NVTE_CHECK(input.data.dtype == grad.data.dtype,
"Input and incoming gradient types must match.");
const size_t tot_elts = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryGradKernelLauncher<nvec, Empty, drelu<fp32, fp32>>(
reinterpret_cast<const IType*>(grad.data.dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
tot_elts,
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
void reglu(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "reglu_input");
CheckOutputTensor(*output, "reglu_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
"Input shape[0] must be equal to output shape[0].");
NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
"Input shape[1] must be 2x larger than output shape[1].");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, fp32, Empty, relu<fp32, fp32>>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
output->data.shape[0],
output->data.shape[1],
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
void dreglu(const Tensor &grad,
const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(grad, "dreglu_grad");
CheckInputTensor(input, "dreglu_input");
CheckOutputTensor(*output, "dreglu_output");
NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
"Output shape[0] must be equal to grad shape[0].");
NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
"Output shape[1] must be 2x larger than grad shape[1].");
NVTE_CHECK(input.data.shape == output->data.shape,
"Input and output shapes must match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(
reinterpret_cast<const IType*>(grad.data.dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
grad.data.shape[0],
grad.data.shape[1],
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
void nvte_relu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_relu);
using namespace transformer_engine;
relu(*reinterpret_cast<const Tensor*>(input),
act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
......@@ -142,7 +24,7 @@ void nvte_drelu(const NVTETensor grad,
cudaStream_t stream) {
NVTE_API_CALL(nvte_drelu);
using namespace transformer_engine;
drelu(*reinterpret_cast<const Tensor*>(grad),
dact_fn<fp32, Empty, drelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
......@@ -153,7 +35,7 @@ void nvte_reglu(const NVTETensor input,
cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
reglu(*reinterpret_cast<const Tensor*>(input),
gated_act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
......@@ -164,7 +46,8 @@ void nvte_dreglu(const NVTETensor grad,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dreglu);
using namespace transformer_engine;
dreglu(*reinterpret_cast<const Tensor*>(grad),
dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
......
......@@ -4,83 +4,38 @@
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/activation.h>
#include <cuda_runtime.h>
#include "../util/vectorized_pointwise.h"
#include "./activation_template.h"
#include "../util/math.h"
#include "../common.h"
namespace transformer_engine {
void swiglu(const Tensor &input,
Tensor *output,
void nvte_swish(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
CheckInputTensor(input, "geglu_input");
CheckOutputTensor(*output, "geglu_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
"Input shape[0] must be equal to output shape[0].");
NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
"Input shape[1] must be 2x larger than output shape[1].");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, fp32, Empty, swish<fp32, fp32>>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
output->data.shape[0],
output->data.shape[1],
{},
NVTE_API_CALL(nvte_swish);
using namespace transformer_engine;
act_fn<fp32, Empty, swish<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
); // NOLINT(*)
); // NOLINT(*)
}
void dswiglu(const Tensor &grad,
const Tensor &input,
Tensor *output,
void nvte_dswish(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
CheckInputTensor(grad, "dswiglu_grad");
CheckInputTensor(input, "dswiglu_input");
CheckOutputTensor(*output, "dswiglu_output");
NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
"Output shape[0] must be equal to grad shape[0].");
NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
"Output shape[1] must be 2x larger than grad shape[1].");
NVTE_CHECK(input.data.shape == output->data.shape,
"Input and output shapes must match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, fp32, Empty, swish<fp32, fp32>, dswish<fp32, fp32>>(
reinterpret_cast<const IType*>(grad.data.dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
grad.data.shape[0],
grad.data.shape[1],
{},
NVTE_API_CALL(nvte_dswish);
using namespace transformer_engine;
dact_fn<fp32, Empty, dswish<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
void nvte_swiglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
swiglu(*reinterpret_cast<const Tensor*>(input),
gated_act_fn<fp32, Empty, swish<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
......@@ -91,7 +46,8 @@ void nvte_dswiglu(const NVTETensor grad,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dswiglu);
using namespace transformer_engine;
dswiglu(*reinterpret_cast<const Tensor*>(grad),
dgated_act_fn<fp32, Empty, swish<fp32, fp32>, dswish<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
......
......@@ -61,24 +61,24 @@ void nvte_dgeglu(const NVTETensor grad,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute RELU activation of the input.
/*! \brief Compute SiLU activation of the input.
*
* \param[in] input Input tensor for RELU activation.
* \param[in] input Input tensor for GELU activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_relu(const NVTETensor input,
void nvte_swish(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute RELU activation gradient.
/*! \brief Compute Swish activation gradient.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for RELU activation.
* \param[in] input Input tensor for Swish activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_drelu(const NVTETensor grad,
void nvte_dswish(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
......@@ -105,6 +105,29 @@ void nvte_dswiglu(const NVTETensor grad,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute RELU activation of the input.
*
* \param[in] input Input tensor for RELU activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_relu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute RELU activation gradient.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for RELU activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_drelu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute ReGLU activation of the input.
*
* \param[in] input Input tensor of shape [N, H * 2].
......
......@@ -159,6 +159,53 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input,
NVTETensor transposed_output,
cudaStream_t stream);
/*! \brief Compute backward of SiLU operation on the input, then cast and transpose. Additionally,
* reduce the result of the SiLU backward along the first dimension.
*
* This function produces 3 results:
* - `cast_output` is equal to `cast(dSiLU(input))`
* - `transposed_output` is equal to `transpose(cast(dSiLU(input)))`
* - `dbias` is equal to `reduce(dSiLU(input), axis=0)`
*
* Calling this function with workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] swish_input Tensor used as input to the forward of SiLU operation.
* Shape [N, H].
* \param[in,out] cast_output Result of the cast. Shape: [N, H].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[out] dbias Result of the reduction of the dSiLU(input) along the
* first dimension. Shape: [H].
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_dbias_dswish(const NVTETensor input,
const NVTETensor swish_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute dswiglu of the input, additionally does cast and transpose the dswiglu output.
*
* This function produces 2 results:
* - `cast_output` is the result of the cast
* - `transposed_output` is the transposed result of the cast.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] swiglu_input Tensor used as input to the forward of SwiGLU operation.
* Shape [N, H * 2].
* \param[in,out] cast_output Result of the cast. Shape: [N, H * 2].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dswiglu_cast_transpose(const NVTETensor input,
const NVTETensor swiglu_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -619,7 +619,11 @@ void cast_transpose_dbias(const Tensor &input,
); // NOLINT(*)
}
template <int nvec_in, int nvec_out, typename Param>
// TODO Phuong: Change all the names in these generalized functions.
// For now, I keep the old names so that it is easier to do code review
template <typename ComputeType, typename ParamOP,
int nvec_in, int nvec_out, typename Param,
ComputeType (*OP)(ComputeType, const ParamOP&)>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_dbias_dgelu_kernel(const Param param,
......@@ -713,7 +717,7 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
after_dgelu[j].data.elt[k] = OP(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
CType(in[current_in ^ 1][j].data.elt[k]);
}
}
......@@ -779,7 +783,9 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
}
}
template <int nvec_in, int nvec_out, typename Param>
template <typename ComputeType, typename ParamOP,
int nvec_in, int nvec_out, typename Param,
ComputeType (*OP)(ComputeType, const ParamOP&)>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
......@@ -896,7 +902,7 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
after_dgelu[j].data.elt[k] = OP(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
CType(in[current_in ^ 1][j].data.elt[k]);
}
}
......@@ -969,7 +975,11 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
}
}
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType>
template <int nvec_in, int nvec_out,
typename CType, typename IType, typename OType,
typename ParamOP,
CType (*OP1)(CType, const ParamOP&),
CType (*OP2)(CType, const ParamOP&)>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
dgeglu_cast_transpose_kernel(const IType * const input,
......@@ -1068,11 +1078,11 @@ dgeglu_cast_transpose_kernel(const IType * const input,
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
after_dgelu[j].data.elt[k] = OP1(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
CType(in[current_in ^ 1][j].data.elt[k]) *
CType(gate_in[current_in ^ 1][j].data.elt[k]);
after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
gelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k], {});
OP2(gelu_in[current_in ^ 1][j].data.elt[k], {});
}
}
OVec out_trans_0[nvec_in]; // NOLINT(*)
......@@ -1138,7 +1148,11 @@ dgeglu_cast_transpose_kernel(const IType * const input,
}
}
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType>
template <int nvec_in, int nvec_out,
typename CType, typename IType, typename OType,
typename ParamOP,
CType (*OP1)(CType, const ParamOP&),
CType (*OP2)(CType, const ParamOP&)>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
dgeglu_cast_transpose_kernel_notaligned(const IType * const input,
......@@ -1265,11 +1279,11 @@ dgeglu_cast_transpose_kernel_notaligned(const IType * const input,
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
after_dgelu[j].data.elt[k] = OP1(gelu_in[current_in ^ 1][j].data.elt[k], {}) *
CType(in[current_in ^ 1][j].data.elt[k]) *
CType(gate_in[current_in ^ 1][j].data.elt[k]);
after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
gelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k], {});
OP2(gelu_in[current_in ^ 1][j].data.elt[k], {});
}
}
OVec out_trans_0[nvec_in]; // NOLINT(*)
......@@ -1343,6 +1357,8 @@ dgeglu_cast_transpose_kernel_notaligned(const IType * const input,
}
}
template <typename ComputeType, typename ParamOP,
ComputeType (*OP)(ComputeType, const ParamOP&)>
void cast_transpose_dbias_dgelu(const Tensor &input,
const Tensor &gelu_input,
Tensor *cast_output,
......@@ -1407,7 +1423,7 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
using ComputeType = fp32;
// using ComputeType = fp32;
constexpr size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) *
sizeof(Vec<OutputType, nvec_out>);
......@@ -1423,20 +1439,28 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
if (full_tile) {
cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel<nvec_in, nvec_out, Param>,
cudaFuncSetAttribute(
cast_transpose_dbias_dgelu_kernel<ComputeType, Empty,
nvec_in, nvec_out, Param, OP>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_dbias_dgelu_kernel<nvec_in, nvec_out, Param>
cast_transpose_dbias_dgelu_kernel<ComputeType, Empty,
nvec_in, nvec_out, Param, OP>
<<<n_blocks,
cast_transpose_num_threads,
shared_size_transpose,
stream>>>(param, row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel_notaligned<nvec_in, nvec_out, Param>,
cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel_notaligned<
ComputeType, Empty,
nvec_in, nvec_out, Param, OP>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_dbias_dgelu_kernel_notaligned<nvec_in, nvec_out, Param>
cast_transpose_dbias_dgelu_kernel_notaligned<
ComputeType, Empty,
nvec_in, nvec_out, Param, OP>
<<<n_blocks,
cast_transpose_num_threads,
shared_size_transpose,
......@@ -1448,6 +1472,9 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
); // NOLINT(*)
}
template <typename ComputeType, typename ParamOP,
ComputeType (*OP1)(ComputeType, const ParamOP&),
ComputeType (*OP2)(ComputeType, const ParamOP&)>
void dgeglu_cast_transpose(const Tensor &input,
const Tensor &geglu_input,
Tensor *cast_output,
......@@ -1505,11 +1532,14 @@ void dgeglu_cast_transpose(const Tensor &input,
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
if (full_tile) {
cudaFuncSetAttribute(dgeglu_cast_transpose_kernel<nvec_in, nvec_out, fp32,
InputType, OutputType>,
cudaFuncSetAttribute(dgeglu_cast_transpose_kernel<
nvec_in, nvec_out,
ComputeType, InputType, OutputType,
Empty, OP1, OP2>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
dgeglu_cast_transpose_kernel<nvec_in, nvec_out, fp32, InputType, OutputType>
dgeglu_cast_transpose_kernel< nvec_in, nvec_out,
ComputeType, InputType, OutputType, Empty, OP1, OP2>
<<<n_blocks,
cast_transpose_num_threads,
cast_transpose_num_threads / n_warps_per_tile *
......@@ -1524,11 +1554,14 @@ void dgeglu_cast_transpose(const Tensor &input,
reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(dgeglu_cast_transpose_kernel_notaligned<nvec_in, nvec_out, fp32,
InputType, OutputType>,
cudaFuncSetAttribute(dgeglu_cast_transpose_kernel_notaligned<
nvec_in, nvec_out,
ComputeType, InputType, OutputType,
Empty, OP1, OP2>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
dgeglu_cast_transpose_kernel_notaligned<nvec_in, nvec_out, fp32, InputType, OutputType>
dgeglu_cast_transpose_kernel_notaligned<nvec_in, nvec_out,
ComputeType, InputType, OutputType, Empty, OP1, OP2>
<<<n_blocks,
cast_transpose_num_threads,
cast_transpose_num_threads / n_warps_per_tile *
......@@ -1574,7 +1607,8 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu);
using namespace transformer_engine;
cast_transpose_dbias_dgelu(*reinterpret_cast<const Tensor*>(input),
cast_transpose_dbias_dgelu<fp32, Empty, dgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(gelu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
......@@ -1590,9 +1624,44 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu_cast_transpose);
using namespace transformer_engine;
dgeglu_cast_transpose(*reinterpret_cast<const Tensor*>(input),
dgeglu_cast_transpose<fp32, Empty, dgelu<fp32, fp32>, gelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(geglu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
stream);
}
void nvte_cast_transpose_dbias_dswish(const NVTETensor input,
const NVTETensor swish_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dswish);
using namespace transformer_engine;
cast_transpose_dbias_dgelu<fp32, Empty, dswish<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(swish_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(workspace),
stream);
}
void nvte_dswiglu_cast_transpose(const NVTETensor input,
const NVTETensor swiglu_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dswiglu_cast_transpose);
using namespace transformer_engine;
dgeglu_cast_transpose<fp32, Empty, dswish<fp32, fp32>, swish<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(swiglu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
stream);
}
This diff is collapsed.
......@@ -34,6 +34,16 @@ pybind11::dict Registrations() {
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose);
// TODO
dict["te_silu"] = EncapsulateFunction(Silu);
dict["te_silu_fp8"] = EncapsulateFunction(SiluFP8);
dict["te_dsilu"] = EncapsulateFunction(DSilu);
dict["te_dsilu_dbias_cast_transpose"] = EncapsulateFunction(DSiluDBiasCastTranspose);
dict["te_gated_silu"] = EncapsulateFunction(GatedSilu);
dict["te_gated_silu_fp8"] = EncapsulateFunction(GatedSiluFP8);
dict["te_dgated_silu"] = EncapsulateFunction(DGatedSilu);
dict["te_dgated_silu_cast_transpose"] = EncapsulateFunction(DGatedSiluCastTranspose);
//
dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward);
dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8);
dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward);
......@@ -66,7 +76,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes);
m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes);
m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
......
......@@ -14,6 +14,7 @@
#include <stdexcept>
#include <string>
#include <vector>
#include <iostream>
#include "common/common.h"
#include "common/util/logging.h"
......@@ -234,30 +235,6 @@ void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaqu
nvte_dgelu(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), stream);
}
pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto gelu_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size};
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto gelu_input_tensor = TensorWrapper(nullptr, gelu_input_shape, in_dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
TensorWrapper dummy_workspace;
nvte_cast_transpose_dbias_dgelu(input_tensor.data(), gelu_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), dummy_workspace.data(), nullptr);
auto work_shape = MakeShapeVector(dummy_workspace.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
......@@ -466,6 +443,241 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op
output_trans_tensor.data(), stream);
}
void SiluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
scale, scale_inverse);
nvte_swish(input_tensor.data(), output_tensor.data(), stream);
}
void Silu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
SiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output);
}
void SiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
SiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
output);
}
void DSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *silu_input = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto silu_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
nvte_dswish(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(), stream);
}
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size};
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
TensorWrapper dummy_workspace;
// For now, all dbias_dact(-s) have the same workspace size
nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), dummy_workspace.data(), nullptr);
auto work_shape = MakeShapeVector(dummy_workspace.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *silu_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *output_trans = buffers[6];
auto *dbias = buffers[7];
float *amax_out = reinterpret_cast<float *>(buffers[8]);
void *workspace_ptr = buffers[9];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto silu_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_cast_transpose_dbias_dswish(input_tensor.data(), silu_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
}
void GatedSiluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
auto input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
scale, scale_inverse);
nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
}
void GatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GatedSiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr,
output);
}
void GatedSiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GatedSiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
output);
}
void DGatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *silu_input = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto silu_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
nvte_dswiglu(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(), stream);
}
void DGatedSiluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *silu_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *output_trans = buffers[6];
float *amax_out = reinterpret_cast<float *>(buffers[7]);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = desc.shape.to_vector();
auto silu_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
auto output_trans_shape = std::vector<size_t>{n * 2, m};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
nvte_dswiglu_cast_transpose(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(),
output_trans_tensor.data(), stream);
}
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma,
......
......@@ -140,13 +140,14 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// TODO (Phuong): Templating these 9x2 rountines before adding ReGLU, QuickGeLU, Squared ReLu
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
......@@ -167,6 +168,24 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
// SiLU (Swish) activation XLA custom-call entry points: forward, FP8 forward, backward.
void Silu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void SiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// Fused dSiLU + dBias + cast + transpose (mirrors the GeLU variant declared above).
void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len);
// Gated SiLU (SwiGLU) entry points: forward, FP8 forward, backward, and the
// fused backward cast+transpose used on the FP8 path.
void GatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedSiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedSiluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len);
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma,
......
......@@ -943,17 +943,18 @@ class LayerNormMLP(TransformerEngineBase):
fuse_layernorm = FP8Helper.is_fp8_enabled(
) and not self.return_layernorm_output and self.enable_layernorm
# Make sure each tuple is sorted in alphabet order
gated_act_pool = [('gelu', 'linear')]
#('linear', 'silu')] coming
act_pool = [('gelu',)]
#('silu',)] coming
gated_act_pool = [('gelu', 'linear'),
('silu', 'linear')]
act_pool = [('gelu',),
('silu',)]
normalize_acts = []
for act in self.activations:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
normalize_acts = tuple(sorted(normalize_acts))
normalize_acts = tuple(reversed(normalize_acts)
if normalize_acts[0] == 'linear' else normalize_acts)
is_gated = normalize_acts in gated_act_pool
is_act_implemented = normalize_acts in (gated_act_pool + act_pool)
......
......@@ -15,9 +15,13 @@ from .cpp_extensions import gelu
from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
from .cpp_extensions import silu, silu_fp8
from .cpp_extensions import dsilu, dsilu_dbias_cast_transpose
from .cpp_extensions import gated_silu, gated_silu_fp8
from .cpp_extensions import dgated_silu, dgated_silu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
......@@ -27,14 +31,22 @@ activation_dict = {
('gelu',): {'fwd': gelu,
"bwd": dgelu},
('gelu', 'linear'): {'fwd': gated_gelu,
'bwd': dgated_gelu}
'bwd': dgated_gelu},
('silu',): {'fwd': silu,
"bwd": dsilu },
('silu', 'linear'): {'fwd': gated_silu,
'bwd': dgated_silu}
}
# FP8 variants of the activation routines, keyed by the canonicalized
# (alphabetically sorted, 'linear' last) activation-type tuple.
# 'fwd' casts the activation output to FP8; 'bwd' fuses the activation
# gradient with (dbias +) cast + transpose.
# NOTE: a stale duplicate of the gelu/linear 'bwd' line (diff residue,
# missing its trailing comma) was removed — it made the literal invalid.
activation_fp8_dict = {
    ('gelu',): {'fwd': gelu_fp8,
                'bwd': dgelu_dbias_cast_transpose},
    ('gelu', 'linear'): {'fwd': gated_gelu_fp8,
                         'bwd': dgated_gelu_cast_transpose},
    ('silu',): {'fwd': silu_fp8,
                'bwd': dsilu_dbias_cast_transpose},
    ('silu', 'linear'): {'fwd': gated_silu_fp8,
                         'bwd': dgated_silu_cast_transpose}
}
......@@ -47,7 +59,6 @@ def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]
output = _activation_lu(x, activation_type)
return output
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
......@@ -55,12 +66,10 @@ def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable
return _output
def _activation_lu_fwd_rule(x, activation_type):
    """Forward rule: apply the selected activation and stash x for the VJP."""
    fwd_impl = activation_dict[activation_type]["fwd"]
    return fwd_impl(x), (x,)
def _activation_lu_bwd_rule(activation_type, ctx, g):
x, = ctx
assert x.dtype == g.dtype
......@@ -72,6 +81,67 @@ def _activation_lu_bwd_rule(activation_type, ctx, g):
# Register the custom forward/backward rules for _activation_lu.
_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                      scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
                      activation_type: Sequence[Union[str, Callable]]):
    """
    FP8 Activation Unit.

    Builds placeholder cotangent buffers (transposed gradient and bias gradient)
    whose shapes the custom VJP requires, then dispatches to the custom-VJP
    implementation.
    """
    # Gated activations transpose the gradient along different axes than
    # plain activations — pick the matching placeholder shape.
    if len(activation_type) > 1:
        trans_axes = (1, 2, 0)
    else:
        trans_axes = (2, 0, 1)
    placeholder_dx_trans = jnp.empty([x.shape[axis] for axis in trans_axes], dtype=x.dtype)
    placeholder_dbias = jnp.empty(x.shape[-1], dtype=x.dtype)
    return _activation_lu_fp8(x, placeholder_dx_trans, placeholder_dbias, amax,
                              scale, scale_inv, fwd_dtype, bwd_dtype, activation_type)
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
def _activation_lu_fp8(x: jnp.ndarray,
                       dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray,
                       amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                       fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
                       activation_type: Sequence[Union[str, Callable]]):
    """Custom-VJP FP8 activation primal.

    The fwd rule returns ``(output, residuals)``; the primal must return only
    the output — otherwise an undifferentiated call would leak the residual
    tuple to the caller. (Previously the whole tuple was returned.)
    """
    output, _ = _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax,
                                            scale, scale_inv, fwd_dtype, bwd_dtype,
                                            activation_type)
    return output
def _activation_lu_fp8_fwd_rule(x,
                                dx_trans_no_use,    # pylint: disable=unused-argument
                                dbias_no_use,    # pylint: disable=unused-argument
                                amax,
                                scale, scale_inv,
                                fwd_dtype, bwd_dtype,    # pylint: disable=unused-argument
                                activation_type):
    """Forward rule: run the FP8 activation kernel, dequantize back to x.dtype."""
    fwd_kernel = activation_fp8_dict[activation_type]["fwd"]
    # The kernel also returns an updated amax, which is not needed here.
    casted_out, _ = fwd_kernel(x, amax, scale, scale_inv, fwd_dtype)
    out = dequantize(casted_out, x.dtype, scale_inv)
    residuals = (x, amax, scale, scale_inv)
    return out, residuals
def _activation_lu_fp8_bwd_rule(fwd_dtype, bwd_dtype,    # pylint: disable=unused-argument
                                activation_type, ctx, g):
    """Backward rule: fused FP8 activation gradient (+ dbias for the non-gated
    case), dequantized back to the input precision."""
    x, amax, scale, scale_inv = ctx
    bwd_kernel = activation_fp8_dict[activation_type]["bwd"]
    if len(activation_type) > 1:
        # Gated activation: the fused kernel computes no bias gradient,
        # so emit a placeholder to keep the cotangent structure uniform.
        dact, dact_trans, amax_out = \
            bwd_kernel(g, x, amax, scale, scale_inv, bwd_dtype, -1)
        dbias = jnp.empty(x.shape[-1], x.dtype)
    else:
        dact, dact_trans, dbias, amax_out = \
            bwd_kernel(g, x, amax, scale, scale_inv, bwd_dtype, -1)
    dact = dequantize(dact, x.dtype, scale_inv)
    dact_trans = dequantize(dact_trans, x.dtype, scale_inv)
    # One cotangent per differentiable primal argument of _activation_lu_fp8.
    return (dact, dact_trans, dbias, amax_out, scale, scale_inv)
# Wire the custom forward/backward rules into _activation_lu_fp8.
_activation_lu_fp8.defvjp(_activation_lu_fp8_fwd_rule, _activation_lu_fp8_bwd_rule)
def fused_layernorm_fp8_mlp(x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
......@@ -247,11 +317,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
activation_lu_out_scale = scale[gemm2_x_idx]
activation_lu_out_scale_inv = scale_inv[gemm2_x_idx]
activation_lu_fp8 = activation_fp8_dict[activation_type]["fwd"]
activation_lu_fwd_fp8 = activation_fp8_dict[activation_type]["fwd"]
# (batch..., hidden_in) -> (batch..., hidden)
casted_activation_lu_out, updated_activation_lu_amax = activation_lu_fp8(dot_1_output,
activation_lu_out_amax, activation_lu_out_scale,
casted_activation_lu_out, updated_activation_lu_amax = \
activation_lu_fwd_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
activation_lu_out_scale_inv, fwd_dtype)
casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment