"vscode:/vscode.git/clone" did not exist on "ad5f2fe34cf2b3564a2d71500a7a096e25065734"
Unverified commit aad4e173 authored by Phuong Nguyen, committed by GitHub

[JAX] Generalizing Activation Primitives (#810)



* templated primitives and respective C++ functions
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fixes for LayerNormMLP, tests in test_custom_compute all passed
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* added default arg for pybind get_workspace_size funcs
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fixes for TestTransFormer with non-gated act tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* renamed gelu to act
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* improved enum implementation, avoid using magic numbers
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Exposed C++ ActivationEnum to python side
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Changed error messages
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* changed conditional check on input shape for dbias_cast_transpose
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* changed dtype (tol) for bias grad tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fixes so that layer_norm_fp8_mlp can take bias = None
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Set bias = None in flax modules
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 2045a426
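As a usage reference for the generalized API, here is a minimal sketch of how the new tuple-based `activation_lu` entry point is exercised; the import path and tensor sizes are assumptions, while the activation tuples are the ones parametrized in the tests below.

```python
import jax
import jax.numpy as jnp
# Import path assumed for illustration; activation_lu lives in the JAX MLP helpers.
from transformer_engine.jax.mlp import activation_lu

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 2, 64), jnp.bfloat16)    # (batch, act_axis, hidden)

# Gated activations consume one slice of the act axis per tuple entry.
y_geglu = activation_lu(x, activation_type=('gelu', 'linear'))
y_swiglu = activation_lu(x, activation_type=('silu', 'linear'))

# Non-gated activations expect an act axis of size 1.
y_gelu = activation_lu(x[:, :1, :], activation_type=('gelu',))
```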
@@ -194,8 +194,8 @@ class TestFP8Dot:
             b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
             b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
         else:
-            b1 = jax.random.normal(subkeys[3], (0,), jnp.bfloat16)
-            b2 = jax.random.normal(subkeys[4], (0,), jnp.bfloat16)
+            b1 = None
+            b2 = None
         init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
         init_fp8_metas_amax = jnp.zeros(
@@ -300,19 +300,19 @@ class TestFP8Dot:
         assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
                         jnp.asarray(ref_k1_grad, np.float32),
                         dtype=FP8Helper.BWD_DTYPE)
-        assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
-                        jnp.asarray(ref_k2_grad, np.float32),
-                        dtype=FP8Helper.BWD_DTYPE)
         assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
                         jnp.asarray(ref_s_grad, np.float32),
                         dtype=FP8Helper.BWD_DTYPE)
+        assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
+                        jnp.asarray(ref_k2_grad, np.float32),
+                        dtype=FP8Helper.BWD_DTYPE)
         if use_bias:
-            assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
-                            jnp.asarray(ref_b1_grad, np.float32),
-                            dtype=jnp.bfloat16)
             assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
                             jnp.asarray(ref_b2_grad, np.float32),
-                            dtype=jnp.bfloat16)
+                            dtype=FP8Helper.BWD_DTYPE)
+            assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
+                            jnp.asarray(ref_b1_grad, np.float32),
+                            dtype=FP8Helper.BWD_DTYPE)

     @pytest.fixture(name="random_inputs")
@@ -341,13 +341,14 @@ class TestActivationLu:
     def primitive_func(self, inputs):
         return jnp.mean(activation_lu(inputs, activation_type = self.activation_type))

-    @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
+    @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
     @pytest.mark.parametrize('activation_type', [('gelu',),
                                                  ('gelu', 'linear'),
                                                  ('silu',),
                                                  ('silu', 'linear')])
     def test_activation_lu(self, random_inputs, activation_type):
         x = random_inputs
+        x = jnp.repeat(x, len(activation_type), axis=1)
         self.activation_type = activation_type

         value_n_grad_primitive_func = jit(
@@ -355,8 +356,6 @@ class TestActivationLu:
         prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
         ref_out, (ref_grad,) = self.ref_func(x, activation_type)

-        """ prim_grad, = prim_grad """
-        """ ref_grad, = ref_grad """
         assert_allclose(prim_out, ref_out, dtype=x.dtype)
         assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
@@ -372,7 +371,7 @@ class TestActivationLuFP8(TestActivationLu):
                                           activation_type = self.activation_type))

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
+    @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
     @pytest.mark.parametrize('activation_type', [('gelu',),
                                                  ('gelu', 'linear'),
                                                  ('silu',),
@@ -384,6 +383,7 @@ class TestActivationLuFP8(TestActivationLu):
         self.activation_type = activation_type
         x = random_inputs
+        x = jnp.repeat(x, len(activation_type), axis=1)
         value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0, 1, 2, 3, 4, 5,)))
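The shape parametrization above drops from a fixed act axis of 2 to 1 because the test now expands that axis to match the activation tuple; a shape-only illustration of the convention (standalone, nothing here calls the primitives):

```python
import jax.numpy as jnp

x = jnp.zeros((32, 1, 64), jnp.bfloat16)                 # (batch, act_axis, hidden)

# Gated tuples such as ('gelu', 'linear') need one slice per entry, so the
# act axis is repeated to len(activation_type) before calling activation_lu.
x_gated = jnp.repeat(x, 2, axis=1)
assert x_gated.shape == (32, 2, 64)

# Non-gated tuples such as ('gelu',) keep the axis at size 1.
x_plain = jnp.repeat(x, 1, axis=1)
assert x_plain.shape == (32, 1, 64)
```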
......
@@ -529,11 +529,12 @@ void cast_transpose_dbias(const Tensor &input,
                           Tensor *dbias,
                           Tensor *workspace,
                           cudaStream_t stream) {
-  // TODO
-  // CheckInputTensor(input, "cast_transpose_dbias_input");
-  // CheckOutputTensor(*cast_output, "cast_output");
-  // CheckOutputTensor(*transposed_output, "transposed_output");
-  // CheckOutputTensor(*dbias, "dbias");
+  if (workspace->data.dptr != nullptr) {
+    CheckInputTensor(input, "cast_transpose_dbias_input");
+    CheckOutputTensor(*cast_output, "cast_output");
+    CheckOutputTensor(*transposed_output, "transposed_output");
+    CheckOutputTensor(*dbias, "dbias");
+  }
   NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
   NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
......
@@ -25,25 +25,14 @@ pybind11::dict Registrations() {
   pybind11::dict dict;
   dict["te_transpose"] = EncapsulateFunction(Transpose);
   dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose);
-  dict["te_gelu"] = EncapsulateFunction(Gelu);
-  dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8);
-  dict["te_dgelu"] = EncapsulateFunction(DGelu);
-  dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose);
+  dict["te_act_lu"] = EncapsulateFunction(ActLu);
+  dict["te_act_lu_fp8"] = EncapsulateFunction(ActLuFP8);
+  dict["te_dact_lu"] = EncapsulateFunction(DActLu);
   dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose);
-  dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu);
-  dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
-  dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
-  dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose);
-  // TODO
-  dict["te_silu"] = EncapsulateFunction(Silu);
-  dict["te_silu_fp8"] = EncapsulateFunction(SiluFP8);
-  dict["te_dsilu"] = EncapsulateFunction(DSilu);
-  dict["te_dsilu_dbias_cast_transpose"] = EncapsulateFunction(DSiluDBiasCastTranspose);
-  dict["te_gated_silu"] = EncapsulateFunction(GatedSilu);
-  dict["te_gated_silu_fp8"] = EncapsulateFunction(GatedSiluFP8);
-  dict["te_dgated_silu"] = EncapsulateFunction(DGatedSilu);
-  dict["te_dgated_silu_cast_transpose"] = EncapsulateFunction(DGatedSiluCastTranspose);
-  //
+  dict["te_dact_lu_dbias_cast_transpose"] = EncapsulateFunction(DActLuDBiasCastTranspose);
+  dict["te_dgated_act_lu_cast_transpose"] = EncapsulateFunction(DGatedActLuCastTranspose);
   dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward);
   dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8);
   dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward);
@@ -67,8 +56,11 @@ pybind11::dict Registrations() {
 PYBIND11_MODULE(transformer_engine_jax, m) {
   m.def("registrations", &Registrations);
-  m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor);
-  m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor);
+  m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor,
+        pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0);
+  m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor,
+        pybind11::arg(), pybind11::arg(), pybind11::arg(),
+        pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0);
   m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
   m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
   m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
@@ -109,6 +101,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
       .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
       .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD);
+
+  pybind11::enum_<NVTE_Activation_Enum>(m, "NVTE_Activation_Enum", pybind11::module_local())
+      .value("GELU", NVTE_Activation_Enum::GELU)
+      .value("GEGLU", NVTE_Activation_Enum::GEGLU)
+      .value("SILU", NVTE_Activation_Enum::SILU)
+      .value("SWIGLU", NVTE_Activation_Enum::SWIGLU);
   pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
       .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
       .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
......
@@ -43,14 +43,24 @@ struct Shape {
   }
 };

+enum class NVTE_Activation_Enum {
+  GELU,
+  GEGLU,
+  SILU,
+  SWIGLU,
+};
+
+size_t get_activation_len(NVTE_Activation_Enum act_enum);
+
 struct CustomCallCommonDescriptor {
   Shape shape;
   DType in_dtype;
   DType out_dtype;
+  size_t act_enum;
 };

 pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
-                                               DType out_dtype);
+                                               DType out_dtype, size_t act_enum = 0);

 struct CustomCallCommonWkDescriptor {
   Shape shape;
@@ -58,11 +68,13 @@ struct CustomCallCommonWkDescriptor {
   DType in_dtype;
   DType out_dtype;
   DType wk_dtype;
+  size_t act_enum;
 };

 pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
-                                                 const std::vector<size_t> &wkshape, DType in_dtype,
-                                                 DType out_dtype, DType wk_dtype);
+                                                 const std::vector<size_t> &wkshape,
+                                                 DType in_dtype, DType out_dtype, DType wk_dtype,
+                                                 size_t act_enum = 0);

 struct CustomCallNormDescriptor {
   size_t batch_size;
@@ -140,17 +152,16 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o
 void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

-// TODO (Phuong): Templating these 9x2 rountines before adding ReGLU, QuickGeLU, Squared ReLu
-void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
-void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
-void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
+void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
+void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
+void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

 pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                         DType in_dtype, DType out_dtype);
-void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
+void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len);

 pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
@@ -159,31 +170,7 @@ pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hi
 void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len);

-void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
-void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
-void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
-void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
-                             size_t opaque_len);
-void Silu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
-void SiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
-void DSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
-void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
-                             size_t opaque_len);
-void GatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
-void GatedSiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
-void DGatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
-void DGatedSiluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
-                             size_t opaque_len);
+void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
+                              size_t opaque_len);

 pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
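The four enum members correspond to the activation tuples used on the JAX side, and get_activation_len reports how many input slices each kind consumes; a hypothetical mapping written out for illustration (this dict is not part of the diff):

```python
# Assumed correspondence between JAX-side activation_type tuples and the enum,
# together with the act-axis length each one implies.
ACTIVATION_ENUM_BY_TYPE = {
    ('gelu',):          ('GELU',   1),   # non-gated, 1 slice
    ('gelu', 'linear'): ('GEGLU',  2),   # gated, 2 slices
    ('silu',):          ('SILU',   1),   # non-gated, 1 slice
    ('silu', 'linear'): ('SWIGLU', 2),   # gated, 2 slices
}
```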
......
@@ -955,7 +955,6 @@ class LayerNormMLP(TransformerEngineBase):
         normalize_acts = tuple(reversed(normalize_acts)
                                if normalize_acts[0] == 'linear' else normalize_acts)
-        is_gated = normalize_acts in gated_act_pool
         is_act_implemented = normalize_acts in (gated_act_pool + act_pool)

         use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\
@@ -1052,8 +1051,8 @@ class LayerNormMLP(TransformerEngineBase):
                                       axes=self.bias_axes_2)
                 bias_2 = bias_2.astype(self.dtype)
             else:
-                bias_1 = jnp.empty(0, self.dtype)
-                bias_2 = jnp.empty(0, self.dtype)
+                bias_1 = None
+                bias_2 = None
             out = fused_layernorm_fp8_mlp(y,
                                           scale,
@@ -1134,7 +1133,6 @@ class LayerNormMLP(TransformerEngineBase):
                 x += jnp.reshape(bias_1, bias_1_shape)
             x = checkpoint_name(x, ffn1_ckpt_name)
-
             activations = []
             if is_act_implemented:
                 z = activation_lu(x, normalize_acts)
@@ -1144,7 +1142,7 @@ class LayerNormMLP(TransformerEngineBase):
                     x_i = _convert_to_activation_function(act_fn)(x[idx])
                     activations.append(x_i)
                 z = functools.reduce(operator.mul, activations)
-            if not is_gated:
+            if num_activations == 1:
                 z = jnp.reshape(z, (*z.shape[:-2], -1))
             z = nn.Dropout(rate=self.intermediate_dropout_rate,
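In the non-fused reference branch above, the per-activation branches are combined with an elementwise product; a minimal conceptual sketch of what that reduce computes for a ('silu', 'linear') tuple (the function name and slicing below are my own, not the module's):

```python
import functools
import operator
import jax
import jax.numpy as jnp

def swiglu_reference(x_act: jnp.ndarray, x_lin: jnp.ndarray) -> jnp.ndarray:
    # Each tuple entry acts on its own slice of the act axis; 'linear' is the
    # identity branch. The branches are then multiplied elementwise, matching
    # functools.reduce(operator.mul, activations) above.
    branches = [jax.nn.silu(x_act), x_lin]
    return functools.reduce(operator.mul, branches)
```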
......
@@ -11,14 +11,8 @@ import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name

 from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose
-from .cpp_extensions import gelu
-from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose
-from .cpp_extensions import gated_gelu, gated_gelu_fp8
-from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
-from .cpp_extensions import silu, silu_fp8
-from .cpp_extensions import dsilu, dsilu_dbias_cast_transpose
-from .cpp_extensions import gated_silu, gated_silu_fp8
-from .cpp_extensions import dgated_silu, dgated_silu_cast_transpose
+from .cpp_extensions import act_lu, act_lu_fp8, dact_lu
+from .cpp_extensions import dact_lu_dbias_cast_transpose, dgated_act_lu_cast_transpose
 from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
 from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
 from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize
@@ -26,44 +20,6 @@ from .layernorm import canonicalize_layernorm_type
 from .fp8 import FP8Helper, FP8MetaPackage
 from .sharding import with_sharding_constraint_by_logical_axes

-activation_dict = {
-    ('gelu',): {
-        'fwd': gelu,
-        "bwd": dgelu
-    },
-    ('gelu', 'linear'): {
-        'fwd': gated_gelu,
-        'bwd': dgated_gelu
-    },
-    ('silu',): {
-        'fwd': silu,
-        "bwd": dsilu
-    },
-    ('silu', 'linear'): {
-        'fwd': gated_silu,
-        'bwd': dgated_silu
-    }
-}
-
-activation_fp8_dict = {
-    ('gelu',): {
-        'fwd': gelu_fp8,
-        'bwd': dgelu_dbias_cast_transpose
-    },
-    ('gelu', 'linear'): {
-        'fwd': gated_gelu_fp8,
-        'bwd': dgated_gelu_cast_transpose
-    },
-    ('silu',): {
-        'fwd': silu_fp8,
-        'bwd': dsilu_dbias_cast_transpose
-    },
-    ('silu', 'linear'): {
-        'fwd': gated_silu_fp8,
-        'bwd': dgated_silu_cast_transpose
-    }
-}
-
 def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
     """
@@ -84,7 +40,7 @@ def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable
 def _activation_lu_fwd_rule(x, activation_type):
-    fwd_output = activation_dict[activation_type]["fwd"](x)
+    fwd_output = act_lu(x, activation_type)
     return fwd_output, (x,)
@@ -92,7 +48,7 @@ def _activation_lu_bwd_rule(activation_type, ctx, g):
     x, = ctx
     assert x.dtype == g.dtype
-    dx = activation_dict[activation_type]["bwd"](g, x)
+    dx = dact_lu(g, x, activation_type)
     dx = jnp.reshape(dx, x.shape)
     return (dx,)
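These fwd/bwd rules follow JAX's custom_vjp pattern: the forward saves the primal input as residuals, and the backward feeds the cotangent plus residuals to the fused backward kernel. A stripped-down, self-contained sketch of the same wiring, with a plain jnp activation standing in for the custom call:

```python
from functools import partial
import jax
import jax.numpy as jnp

def _toy_act(x, activation_type):
    return jax.nn.gelu(x) if activation_type == ('gelu',) else jax.nn.silu(x)

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def toy_activation_lu(x, activation_type):
    return _toy_act(x, activation_type)

def _toy_fwd(x, activation_type):
    # Residuals mirror _activation_lu_fwd_rule: keep the primal input.
    return _toy_act(x, activation_type), (x,)

def _toy_bwd(activation_type, ctx, g):
    # Mirrors _activation_lu_bwd_rule: evaluate the activation's VJP at x.
    x, = ctx
    _, vjp = jax.vjp(lambda t: _toy_act(t, activation_type), x)
    dx, = vjp(g)
    return (dx,)

toy_activation_lu.defvjp(_toy_fwd, _toy_bwd)

# Example: gradient of a mean, as in TestActivationLu.primitive_func.
grad_fn = jax.grad(lambda x: jnp.mean(toy_activation_lu(x, ('gelu',))))
```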
@@ -106,7 +62,7 @@ def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, sca
     """
     Activation Unit
     """
-    transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1)
+    transpose_indices = (1, 2, 0)
     dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype)
     dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
@@ -127,19 +83,15 @@ def _activation_lu_fp8(x: jnp.ndarray, dx_trans_no_use: jnp.ndarray, dbias_no_us
     return output

-def _activation_lu_fp8_fwd_rule(
-        x,
+def _activation_lu_fp8_fwd_rule(x,
         dx_trans_no_use,  # pylint: disable=unused-argument
         dbias_no_use,  # pylint: disable=unused-argument
         amax,
-        scale,
-        scale_inv,
-        fwd_dtype,
-        bwd_dtype,  # pylint: disable=unused-argument
+        scale, scale_inv,
+        fwd_dtype, bwd_dtype,  # pylint: disable=unused-argument
         activation_type):
-    activation_lu_out, _ = activation_fp8_dict[activation_type]["fwd"](x, amax, scale, scale_inv,
-                                                                       fwd_dtype)
+    activation_lu_out, _ = act_lu_fp8(x, amax, scale, scale_inv, fwd_dtype,
+                                      activation_type)
     activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
     ctx = (x, amax, scale, scale_inv)
     return activation_lu_out, ctx
@@ -153,14 +105,14 @@ def _activation_lu_fp8_bwd_rule(
         g):
     x, amax, scale, scale_inv = ctx
-    activation_lu_fp8_bwd = activation_fp8_dict[activation_type]["bwd"]
     if len(activation_type) > 1:    #gated, no bias
         dactivation_lu, dactivation_lu_trans, amax_out = \
-            activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1)
+            dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype, -1, activation_type)
         dbias = jnp.empty(x.shape[-1], x.dtype)
-    else:
+    else:    #not gated, with bias
         dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
-            activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1)
+            dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype,
+                                         -1, -2, activation_type)
     dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
     dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
     ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv)
@@ -262,7 +214,6 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
         activation_type,
         use_bias):
-    is_gated = len(activation_type) > 1
     # x should be in shape of (batch..., hidden)
     # Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out)
     # Kernel_2 should be in shape of (Hidden_in, Hidden_out)
@@ -276,15 +227,9 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
     assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
     assert kernel_1.shape[-1] == kernel_2.shape[0]

-    # Squeeze act axis
-    # (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out)
-    if not is_gated:
-        kernel_1 = jnp.squeeze(kernel_1, axis=-2)
-
     maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
         FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
     fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)

     scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
     amax = FP8Helper.update_amax_history(amax)
@@ -337,8 +282,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
                                  (x_contracting_dims, (0,)),
                                  get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
     if use_bias:
-        bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape
-        dot_1_output += jnp.reshape(bias_1, bias_1_shape)
+        bias_1_shape = bias_1.shape
+        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
+        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
+    else:
+        bias_1_shape = None
     dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

     gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
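The reshape above just left-pads the bias shape with singleton axes so it broadcasts against the GEMM output, while the unpadded shape is what gets stashed in ctx; a tiny standalone check (array sizes are assumptions):

```python
import jax.numpy as jnp

dot_1_output = jnp.zeros((32, 128, 2, 64))     # (batch..., act, hidden), sizes assumed
bias_1 = jnp.zeros((2, 64))                    # one bias vector per activation branch

bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
assert bias_1_new_shape == (1, 1, 2, 64)       # singleton axes added on the left
out = dot_1_output + jnp.reshape(bias_1, bias_1_new_shape)   # broadcasts over batch axes
```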
@@ -347,12 +295,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
     activation_lu_out_scale = scale[gemm2_x_idx]
     activation_lu_out_scale_inv = scale_inv[gemm2_x_idx]

-    activation_lu_fwd_fp8 = activation_fp8_dict[activation_type]["fwd"]
     # (batch..., hidden_in) -> (batch..., hidden)
     casted_activation_lu_out, updated_activation_lu_amax = \
-        activation_lu_fwd_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
-                              activation_lu_out_scale_inv, fwd_dtype)
+        act_lu_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
+                   activation_lu_out_scale_inv, fwd_dtype, activation_type)

     casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
         casted_activation_lu_out, dot_2_input_axes)
@@ -370,15 +317,18 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
                                  get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
     if use_bias:
-        bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape
-        dot_2_output += jnp.reshape(bias_2, bias_2_shape)
+        bias_2_shape = bias_2.shape
+        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
+        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
+    else:
+        bias_2_shape = None
     dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

     ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
            casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
            updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
-           x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape, maybe_fp32_to_fm32)
+           x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32)

     return dot_2_output, ctx
@@ -403,8 +353,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
         updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
         x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx

-    is_gated = len(activation_type) > 1
-
     gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)

     grad_amax = amax[gemm2_grad_idx, 0:1]
@@ -413,7 +361,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
     # Since the sharding of outputs should be the same as dot_1's input
     grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
-
     if use_bias:
         casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
             dbias_cast_transpose(grad, grad_amax, grad_scale,
@@ -427,7 +374,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
                                 grad_scale_inv, bwd_dtype,
                                 static_axis_boundary=-1,
                                 transpose_axis_boundary=-1)
-        dbias_2 = jnp.empty(bias_2_shape, grad.dtype)
+        dbias_2 = None

     casted_activation_lu_out_t = transpose(casted_activation_lu_out,
                                            static_axis_boundary=-1,
@@ -453,11 +400,9 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
     dactivation_lu_scale = scale[gemm1_grad_idx]
     dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx]

-    dactivation_lu_cast_transpose = activation_fp8_dict[activation_type]["bwd"]
-    dactivation_lu = activation_dict[activation_type]["bwd"](dgrad_2, dot_1_output)
-
-    if is_gated:
+    if len(activation_type) > 1:    # if gated
         if use_bias:
+            dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
             casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
                 dbias_cast_transpose(
                     dactivation_lu,
@@ -470,19 +415,20 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
             dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
         else:
             casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
-                dactivation_lu_cast_transpose(
+                dgated_act_lu_cast_transpose(
                     dgrad_2,
                     dot_1_output,
                     dactivation_lu_amax,
                     dactivation_lu_scale,
                     dactivation_lu_scale_inv,
                     bwd_dtype,
-                    static_axis_boundary=-1)
-            dbias_1 = jnp.empty(bias_1_shape, bwd_dtype)
+                    static_axis_boundary=-1,
+                    activation_type=activation_type)
+            dbias_1 = None
     else:
         if use_bias:
-            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
-                dactivation_lu_cast_transpose(
+            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax=\
+                dact_lu_dbias_cast_transpose(
                     dgrad_2,
                     dot_1_output,
                     dactivation_lu_amax,
@@ -490,9 +436,11 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
                     dactivation_lu_scale_inv,
                     bwd_dtype,
                     static_axis_boundary=-1,
-                    transpose_axis_boundary=-1)
+                    transpose_axis_boundary=-2,
+                    activation_type=activation_type)
             dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
         else:
+            dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
             casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
                 cast_transpose(
                     dactivation_lu,
@@ -501,28 +449,20 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
                     dactivation_lu_scale_inv,
                     bwd_dtype,
                     static_axis_boundary=-1,
-                    transpose_axis_boundary=-1)
-            dbias_1 = jnp.empty(bias_1_shape, bwd_dtype)
+                    transpose_axis_boundary=-2)
+            dbias_1 = None

     ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

     # (hidden, batch...) x (hidden, batch...)
     gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
-    xt_batch_dims_2 = xt_batch_dims if not is_gated \
-        else tuple(i + 1 for i in xt_batch_dims)
+    xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims)
     wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
                            dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
                            get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

-    # Expand act axis to match the shape with the given kernel_1
-    if not is_gated:
-        wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2)
-
-    # (batch..., hidden_out) x (hidden_in, hidden_out)
-    if is_gated:
-        x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
-                              (1, 2))
-    else:
-        x_contracting_dims = (x_contracting_dims, (1,))
+    x_contracting_dims = ((min(x_contracting_dims),) + tuple(
+        i + 1 for i in x_contracting_dims), (1,2))
     kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
     dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv,
                            kernel_1_scale_inv, grad.dtype, x_contracting_dims,
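Since kernel_1 now always keeps its act axis (the squeeze and expand_dims special cases are gone), the dgrad contraction always runs over both the act and hidden_out axes; a standalone shape check of that contraction pattern, using lax.dot_general as a stand-in for fp8_dot_impl (sizes assumed):

```python
import jax
import jax.numpy as jnp

batch, hidden_in, num_act, hidden_out = 32, 128, 2, 256
dact = jnp.zeros((batch, num_act, hidden_out))         # casted_dactivation_lu
kernel_1 = jnp.zeros((hidden_in, num_act, hidden_out))

# Contract dact axes (1, 2) against kernel_1 axes (1, 2), matching the
# x_contracting_dims = ((..., shifted), (1, 2)) construction above.
dgrad_1 = jax.lax.dot_general(dact, kernel_1,
                              dimension_numbers=(((1, 2), (1, 2)), ((), ())))
assert dgrad_1.shape == (batch, hidden_in)
```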
......