Unverified commit 32d1eb11 authored by Zhenhuan Liu, committed by GitHub

FP8 Support for MCore MoE (#648)

* Add support for MoE with FP8.
Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Fix unittest.
Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Fix error in linear backward.
Signed-off-by: Dennis Liu <denliu@nvidia.com>

---------

Signed-off-by: Dennis Liu <denliu@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
parent 9709147e
@@ -9,7 +9,11 @@ from contextlib import nullcontext
 import torch
 import pytest
-from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
+from transformer_engine.pytorch.fp8 import (
+    fp8_autocast,
+    FP8GlobalStateManager,
+    fp8_model_init,
+)
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
@@ -107,6 +111,7 @@ if is_bf16_compatible():  # bf16 requires sm_80 or higher
     param_types.append(torch.bfloat16)
 all_boolean = [True, False]
+batch_sizes_with_zero = [0, 1, 2]
 all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
 all_normalizations = ["LayerNorm", "RMSNorm"]
@@ -456,6 +461,45 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
     _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+
+
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes_with_zero)
+@pytest.mark.parametrize("model", ["small", "weird"])
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_model_params", all_boolean)
+@pytest.mark.parametrize("use_bias", all_boolean)
+def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
+    config = model_configs[model]
+    ffn_hidden_size = 4 * config.hidden_size
+    num_tokens = bs * config.seq_len
+
+    if fp8_recipe is not None:
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if not config.is_fp8_supported():
+            pytest.skip("Model config does not support FP8")
+    use_fp8 = fp8_recipe is not None
+
+    with fp8_model_init(enabled=use_fp8 and fp8_model_params):
+        te_linear = Linear(
+            config.hidden_size,
+            ffn_hidden_size,
+            bias=use_bias,
+            params_dtype=dtype,
+        ).cuda()
+    inp_hidden_states = torch.randn(
+        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
+    ).cuda()
+
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+        out = te_linear(inp_hidden_states)
+    loss = out.sum()
+    loss.backward()
+
+    assert out.shape == (num_tokens, ffn_hidden_size)
+
+
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", ["small", "weird"])
......
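Context for the new test: in a Mixture-of-Experts layer, the router may assign zero tokens to a given expert on any step, so each expert's `Linear` (and the FP8 kernels behind it) must tolerate an empty input batch. A minimal sketch of where such empty batches come from (plain-PyTorch toy router with hypothetical names; no Transformer Engine required):

```python
import torch

tokens = torch.randn(8, 16)               # (num_tokens, hidden)
router_logits = torch.randn(8, 4)         # scores for 4 experts
expert_ids = router_logits.argmax(dim=1)  # top-1 routing

experts = [torch.nn.Linear(16, 32) for _ in range(4)]
for eid, expert in enumerate(experts):
    routed = tokens[expert_ids == eid]
    # 'routed' may legitimately have shape (0, 16) when no token picked this
    # expert; the layer must still produce a (0, 32) output and participate
    # in backward without launching invalid kernels.
    out = expert(routed)
    assert out.shape == (routed.shape[0], 32)
```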
@@ -22,6 +22,7 @@ def cast_to_fp8(
     """Cast input to FP8"""
     if out is not None:
+        if inp.nelement() > 0:
             torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
                 inp,
                 fp8_meta_tensor.scale,
@@ -32,6 +33,7 @@
                 otype
             )
         return None
     return torch.ops.tex_ts.cast_to_fp8_ts(
         inp,
         fp8_meta_tensor.scale,
@@ -41,7 +43,6 @@
         otype,
     )


 def cast_from_fp8(
     inp: torch.Tensor,
     fp8_meta_tensor: tex.FP8TensorMeta,
......
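The wrapper change above skips the no-alloc cast kernel when the input is empty while preserving the function's return contract. A simplified sketch of that control flow (hypothetical helper; a plain copy stands in for the FP8 cast kernels):

```python
from typing import Optional

import torch

def cast_sketch(inp: torch.Tensor, out: Optional[torch.Tensor] = None):
    if out is not None:
        if inp.nelement() > 0:  # guard added by this commit
            out.copy_(inp)      # stands in for cast_to_fp8_noalloc_ts
        return None             # no-alloc path always returns None
    return inp.clone()          # stands in for cast_to_fp8_ts

# Zero-element input: no kernel work, contract unchanged.
assert cast_sketch(torch.empty(0, 8), out=torch.empty(0, 8)) is None
```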
@@ -64,6 +64,8 @@ def fp8_gemm(
     bias_dtype = TE_DType[bias_dtype]
     out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
+    if A.nelement() == 0 or B.nelement() == 0:
+        return out, gelu_input
     args = (
         A,
@@ -191,6 +193,8 @@ def gemm(
     grad_bias = empty_tensor
     bias = bias if use_bias else empty_tensor
+    if A.nelement() == 0 or B.nelement() == 0:
+        return out, grad_bias, gelu_input
     assert A.dtype == dtype and B.dtype == dtype, \
         f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
......
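Both `fp8_gemm` and `gemm` (and the fused cast-transpose wrapper below) now return their pre-allocated outputs immediately when either operand is empty, since the underlying kernels are not guaranteed to accept zero-sized matrices. The same idea in a standalone sketch (hypothetical wrapper, not the TE signature):

```python
import torch

def gemm_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Output is allocated up front, mirroring the TE wrappers.
    out = torch.empty(a.shape[0], b.shape[1], dtype=a.dtype, device=a.device)
    if a.nelement() == 0 or b.nelement() == 0:
        return out  # skip the kernel launch for zero-token inputs
    torch.matmul(a, b, out=out)
    return out

# An expert that received zero tokens still yields a well-shaped result.
assert gemm_sketch(torch.randn(0, 8), torch.randn(8, 4)).shape == (0, 4)
```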
@@ -39,6 +39,7 @@ def fp8_cast_transpose_fused(
     if noop_flag is None:
         noop_flag = torch.Tensor()
+    if inp.nelement() > 0:
         tex.fused_cast_transpose_noop(
             inp,
             noop_flag,
......
@@ -19,6 +19,9 @@ at::Tensor cast_to_fp8(const at::Tensor &input,
   auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
+  if (input.numel() == 0)
+    return output;
+
   auto input_cu = makeTransformerEngineTensor(input);
   auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype,
                                                amax.data_ptr(), scale.data_ptr(),
......
@@ -83,6 +83,9 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
                                                    grad_output.size(0),
                                                    DType::kByte);
+  if (M == 0 || N == 0)
+    return {grad_bias, grad_output_cast, grad_output_transpose};
+
   auto input_cu = makeTransformerEngineTensor(grad_output);
   auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N},
                                                     otype, amax.data_ptr(), scale.data_ptr(),
@@ -335,6 +338,8 @@ at::Tensor fp8_transpose(at::Tensor input,
   size_t M = static_cast<size_t>(input.size(0));
   size_t N = static_cast<size_t>(input.size(1));
+  if (M == 0 || N == 0)
+    return input;
   auto output =
       allocateTorchTensor(input.size(1),
......
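The C++ extensions apply the same early-return guards before dispatching CUDA kernels. Note that `fp8_transpose` returns its input unchanged when either dimension is zero, so callers must accept an un-transposed empty tensor. A behavioral sketch in plain PyTorch (the real code launches a TE CUDA kernel):

```python
import torch

def transpose_sketch(x: torch.Tensor) -> torch.Tensor:
    m, n = x.shape
    if m == 0 or n == 0:
        # Mirrors fp8_transpose: the empty input is returned as-is,
        # so the result keeps shape (m, n) rather than (n, m).
        return x
    return x.t().contiguous()

assert transpose_sketch(torch.empty(0, 128)).shape == (0, 128)
```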
@@ -824,6 +824,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker
             if get_rng_state_tracker is None:
                 init_fn(param)
             else:
+                if hasattr(self, "rng_tracker_name") and self.rng_tracker_name:
+                    with get_rng_state_tracker().fork(self.rng_tracker_name):
+                        init_fn(param)
+                else:
                     with get_rng_state_tracker().fork():
                         init_fn(param)
......
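The base-module change forks a named RNG stream during parameter initialization whenever `rng_tracker_name` is set. A minimal sketch of what such a tracker can look like (hypothetical `SimpleRNGTracker` operating on the CPU RNG; Megatron-Core provides the real CUDA RNG states tracker):

```python
import contextlib

import torch

class SimpleRNGTracker:
    """Toy RNG tracker: each named state forks a reproducible stream."""

    def __init__(self):
        self.states = {}

    def add(self, name: str, seed: int) -> None:
        # Register a named stream seeded independently of the outer RNG.
        outer = torch.get_rng_state()
        torch.manual_seed(seed)
        self.states[name] = torch.get_rng_state()
        torch.set_rng_state(outer)

    @contextlib.contextmanager
    def fork(self, name: str = "model-parallel-rng"):
        # Swap the named state in, run the body, then swap the outer state back.
        outer = torch.get_rng_state()
        torch.set_rng_state(self.states[name])
        try:
            yield
        finally:
            self.states[name] = torch.get_rng_state()  # advance the stream
            torch.set_rng_state(outer)
```

Forking by name lets different parameter groups (for example, data-parallel versus expert-parallel weights) draw from independent, reproducible streams.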
@@ -152,7 +152,6 @@ class _Linear(torch.autograd.Function):
             inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
         else:
             inputmat_total = inputmat
-
         if fp8:
             if _NVTE_DEBUG:
                 print('[Linear]: using FP8 forward')
@@ -664,6 +663,10 @@ class Linear(TransformerEngineBaseModule):
     init_method : Callable, default = `None`
                  used for initializing weights in the following way: `init_method(weight)`.
                  When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
+    get_rng_state_tracker : Callable, default = `None`
+                 used to get the random number generator state tracker for initializing weights.
+    rng_tracker_name : str, default = `None`
+                 the name passed to `get_rng_state_tracker()` to select a specific RNG state tracker.
     parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
                       Configuration for splitting the weight and bias tensors along dim 0 into
                       multiple PyTorch parameters. If a list or tuple of strings is provided,
@@ -723,6 +726,7 @@
         tp_group: Optional[dist_group_type] = None,
         tp_size: int = 1,
         get_rng_state_tracker: Optional[Callable] = None,
+        rng_tracker_name: Optional[str] = None,
         init_method: Optional[Callable] = None,
         bias: bool = True,
         return_bias: bool = False,
@@ -753,6 +757,8 @@
         ), "Userbuffer communication backend not available."
         self.ub_name = ub_name
         self.get_rng_state_tracker = get_rng_state_tracker
+        self.rng_tracker_name = rng_tracker_name
+
         if device == 'meta':
             assert parameters_split is None, ("Cannot split module parameters "
                                               "on 'meta' device.")
......
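A usage sketch of the new `rng_tracker_name` argument, reusing the `SimpleRNGTracker` sketch above (tracker name and seed are illustrative; in Megatron-Core the tracker callable would return its CUDA RNG states tracker):

```python
from transformer_engine.pytorch import Linear

tracker = SimpleRNGTracker()  # from the sketch above
tracker.add("expert-parallel-rng", seed=1234)

linear = Linear(
    1024,
    4096,
    get_rng_state_tracker=lambda: tracker,   # called during parameter init
    rng_tracker_name="expert-parallel-rng",  # forked by name, as in the hunk above
)
```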