Unverified Commit 4c9626e7 authored by Xin Yao, committed by GitHub

[PyTorch][MoE] Enable New Recipes for Grouped Linear (#1525)



* Enable MXFP8 and Per-Tensor Current Scaling for Grouped Linear
Signed-off-by: Xin Yao <xiny@nvidia.com>

* enable float8blockwise
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: Xin Yao <xiny@nvidia.com>

* remove grouped linear parallel mode test
Signed-off-by: Xin Yao <xiny@nvidia.com>

* update test
Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments
Signed-off-by: Xin Yao <xiny@nvidia.com>

* internal=False for now
Signed-off-by: Xin Yao <xiny@nvidia.com>

* remove unused import
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6117b20c
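For context, a minimal usage sketch of what this change enables. This snippet is not part of the commit; it assumes the public transformer_engine.pytorch.GroupedLinear API and the recipe classes in transformer_engine.common.recipe.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Four experts sharing one GroupedLinear; token counts per expert are given in m_splits.
num_gemms, in_features, out_features = 4, 1024, 4096
layer = te.GroupedLinear(
    num_gemms, in_features, out_features,
    bias=False, params_dtype=torch.bfloat16, device="cuda",
)
m_splits = [256, 256, 256, 256]
inp = torch.randn(sum(m_splits), in_features, device="cuda", dtype=torch.bfloat16)

# Per-tensor current scaling; MXFP8BlockScaling / Float8BlockScaling are used the same way.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
    out = layer(inp, m_splits)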
......@@ -1470,7 +1470,6 @@ def _test_grouped_linear_accuracy(
if num_gemms > 1:
split_size = 1
if fp8:
if recipe.delayed():
split_size = 16
if recipe.mxfp8():
split_size = 128
......@@ -1509,12 +1508,11 @@ def _test_grouped_linear_accuracy(
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_grouped_linear_accuracy(
......@@ -1522,22 +1520,18 @@ def test_grouped_linear_accuracy(
num_gemms,
bs,
model,
fp8,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
parallel_mode=None,
):
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches
pytest.skip("MXFP8 unsupported for grouped linear.")
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
if recipe.float8_block_scaling():
pytest.skip("Grouped linear for FP8 blockwise unsupported.")
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
......@@ -1591,24 +1585,7 @@ def test_grouped_linear_accuracy(
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("parallel_mode", ["column", "row"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe):
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=6,
bs=2,
model="126m",
fp8=True,
recipe=recipe,
fp8_model_params=True,
parallel_mode=parallel_mode,
fuse_wgrad_accumulation=True,
)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
def test_grouped_linear_accuracy_single_gemm(recipe):
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
......@@ -1616,7 +1593,6 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
num_gemms=1,
bs=2,
model="126m",
fp8=True,
recipe=recipe,
fp8_model_params=True,
fuse_wgrad_accumulation=True,
......@@ -1626,9 +1602,12 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
"""Padding tensor shapes to multiples of 16."""
align_size = 16
if recipe.mxfp8():
align_size = 32
padded_tokens_per_expert = [
(num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
(num_tokens + align_size - 1) // align_size * align_size
for num_tokens in tokens_per_expert
]
hidden_states = torch.split(hidden_states, tokens_per_expert)
padded_hidden_states = []
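The helper above rounds each expert's token count up to the recipe's alignment (32 for MXFP8, 16 otherwise). The rounding itself is plain ceiling-to-multiple arithmetic; a standalone sketch:

def round_up_to_multiple(num_tokens: int, align_size: int) -> int:
    # Ceiling division, then scale back up:
    # e.g. 33 -> 64 with align_size=32, and 17 -> 32 with align_size=16.
    return (num_tokens + align_size - 1) // align_size * align_size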
......@@ -1729,12 +1708,8 @@ def test_padding_grouped_linear_accuracy(
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches
pytest.skip("MXFP8 unsupported for grouped linear.")
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
if recipe.float8_block_scaling():
pytest.skip("Float8 block scaling unsupported for grouped linear.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
......
......@@ -553,14 +553,10 @@ def test_sanity_grouped_linear(
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8():
pytest.skip("Grouped linear does not support MXFP8")
if fp8_recipe.float8_current_scaling():
pytest.skip("Grouped linear does not support FP8 current scaling")
if fp8_recipe.float8_block_scaling():
pytest.skip("Grouped linear does not support FP8 block scaling")
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
......
......@@ -9,10 +9,9 @@ import os
import torch
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import assert_dim_for_fp8_exec, get_sm_count
from ..utils import get_sm_count
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
......@@ -174,14 +173,6 @@ def general_grouped_gemm(
transa = layout[0] == "T"
transb = layout[1] == "T"
# assert [a.is_contiguous() for a in A]
# assert [b.is_contiguous() for b in B]
if isinstance(A[0], Float8TensorBase):
for a, b in zip(A, B):
assert_dim_for_fp8_exec(a._data)
assert_dim_for_fp8_exec(b._data)
empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms
......@@ -208,6 +199,8 @@ def general_grouped_gemm(
for o in out
] # this should differ with respect to single output
# TODO: Move the swizzle to the C++ side. # pylint: disable=fixme
original_scale_inverses_list = [swizzle_inputs(A[i], B[i], layout) for i in range(num_gemms)]
bias = tex.te_general_grouped_gemm(
A,
transa,
......@@ -227,5 +220,7 @@ def general_grouped_gemm(
use_split_accumulator,
sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
)
for i in range(num_gemms):
reset_swizzled_inputs(A[i], B[i], original_scale_inverses_list[i])
return out, bias, gelu_input
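The grouped-GEMM wrapper above saves the scale-inverse layouts returned by swizzle_inputs, runs the GEMM, then restores them with reset_swizzled_inputs. A self-contained sketch of that save/call/restore pattern; swizzle_scales and restore_scales below are hypothetical stand-ins, not the module's actual helpers:

def swizzle_scales(tensor_state):
    # Pretend to reorder the scale-inverse layout for the kernel; return the original.
    saved = tensor_state["scale_inv"]
    tensor_state["scale_inv"] = "swizzled"
    return saved

def restore_scales(tensor_state, saved):
    tensor_state["scale_inv"] = saved

tensors = [{"scale_inv": f"orig_{i}"} for i in range(3)]
saved_list = [swizzle_scales(t) for t in tensors]   # before the grouped GEMM
# ... te_general_grouped_gemm would run here ...
for t, s in zip(tensors, saved_list):
    restore_scales(t, s)                            # afterwards, callers see the original layout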
......@@ -101,18 +101,22 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);
namespace transformer_engine::pytorch {
/***************************************************************************************************
* Transpose
**************************************************************************************************/
std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
std::optional<std::vector<py::handle>> output_list,
std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
std::optional<std::vector<py::object>> output_list,
std::vector<py::handle> quantizer_list,
transformer_engine::DType otype);
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
std::optional<at::Tensor> output = std::nullopt);
} // namespace transformer_engine::pytorch
namespace transformer_engine::pytorch {
/***************************************************************************************************
......
......@@ -196,12 +196,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
py::arg("zero_centered_gamma"));
m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose",
py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype"));
m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
"Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
py::arg("quantizer_list"), py::arg("otype"));
m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
py::call_guard<py::gil_scoped_release>());
m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
......
......@@ -109,6 +109,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
}
const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
opts = opts.dtype(torch::kFloat32);
// TODO: Replace with an empty tensor.
at::Tensor scale_inv = at::reciprocal(scale);
py::object ret;
if (internal) {
......
......@@ -6,27 +6,38 @@
#include <optional>
#include "ATen/core/TensorBody.h"
#include "extensions.h"
#include "pybind.h"
std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
std::optional<std::vector<py::handle>> output_list,
namespace transformer_engine::pytorch {
std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
std::optional<std::vector<py::object>> output_list,
std::vector<py::handle> quantizer_list,
transformer_engine::DType otype) {
using namespace transformer_engine::pytorch;
init_extension();
std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
std::vector<py::object> py_output_objects_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto none = py::none();
if (output_list.has_value()) {
py_output_objects_list = output_list.value();
}
// Choose implementation
// Note: Currently only have fused kernel for FP8 cast-transpose
bool with_fused_kernel = true;
// create TE tensors from input
for (size_t i = 0; i < input_list.size(); i++) {
auto input_tensor = makeTransformerEngineTensor(input_list[i], none);
auto input_tensor = makeTransformerEngineTensor(input_list[i]);
const NVTEShape input_shape = input_tensor.shape();
transformer_engine::TensorWrapper output_tensor;
if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) {
with_fused_kernel = false;
}
if (output_list == std::nullopt) {
std::unique_ptr<Quantizer> quantizer = convert_quantizer(quantizer_list[i]);
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
......@@ -48,16 +59,8 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(),
"Number of input and output tensors must match");
// Choose implementation
// Note: Currently only have fused kernel for FP8 cast-transpose
bool with_fused_kernel = true;
for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
const auto& tensor = nvte_tensor_output_list[i];
if (nvte_tensor_scaling_mode(tensor) != NVTE_DELAYED_TENSOR_SCALING) {
with_fused_kernel = false;
break;
}
if (nvte_tensor_columnwise_data(tensor) == nullptr) {
if (nvte_tensor_columnwise_data(nvte_tensor_output_list[i]) == nullptr) {
with_fused_kernel = false;
break;
}
......@@ -68,10 +71,8 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
} else {
for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
// TODO: switch to nvte_quantize_v2 with advanced numerical options
nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i],
at::cuda::getCurrentCUDAStream());
for (size_t i = 0; i < py_output_objects_list.size(); i++) {
quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt);
}
}
return py_output_objects_list;
......@@ -79,7 +80,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
std::optional<at::Tensor> output) {
using namespace transformer_engine::pytorch;
init_extension();
const auto dim = input.dim();
NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose.");
......@@ -106,3 +107,5 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
return out;
}
} // namespace transformer_engine::pytorch
......@@ -4,12 +4,13 @@
"""FP8 Padding API"""
from typing import Union, List
from typing import List, Optional, Tuple
import torch
import transformer_engine_torch as tex
from ..fp8 import FP8GlobalStateManager
from ..jit import no_torch_dynamo
......@@ -74,22 +75,30 @@ class Fp8Padding(torch.nn.Module):
----------
num_gemms: int
number of GEMMs to be performed simultaneously.
align_size: int, optional
the alignment size for the input tensor. If not provided, the alignment size will
be determined by the FP8 recipe, 32 for MXFP8 and 16 for others.
"""
def __init__(
self,
num_gemms,
num_gemms: int,
align_size: Optional[int] = None,
) -> None:
super().__init__()
self.num_gemms = num_gemms
if align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
else:
self.align_size = align_size
@no_torch_dynamo()
def forward(
self,
inp: torch.Tensor,
m_splits: List[int],
) -> Union[torch.Tensor, List[int]]:
) -> Tuple[torch.Tensor, List[int]]:
"""
Apply the padding to the input.
......@@ -104,7 +113,12 @@ class Fp8Padding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
# FP8 padding calculate
padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
padded_m_splits = [
(m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits
]
# no padding needed
if m_splits == padded_m_splits:
return inp, m_splits
if torch.is_grad_enabled():
fn = _Fp8Padding.apply
......
......@@ -4,12 +4,13 @@
"""FP8 Padding API"""
from typing import List
from typing import List, Optional
import torch
import transformer_engine_torch as tex
from ..fp8 import FP8GlobalStateManager
from ..jit import no_torch_dynamo
......@@ -70,15 +71,23 @@ class Fp8Unpadding(torch.nn.Module):
----------
num_gemms: int
number of GEMMs to be performed simultaneously.
align_size: int, optional
the alignment size for the input tensor. If not provided, the alignment size will
be determined by the FP8 recipe, 32 for MXFP8 and 16 for others.
"""
def __init__(
self,
num_gemms,
num_gemms: int,
align_size: Optional[int] = None,
) -> None:
super().__init__()
self.num_gemms = num_gemms
if align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
else:
self.align_size = align_size
@no_torch_dynamo()
def forward(
......@@ -100,7 +109,12 @@ class Fp8Unpadding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
# FP8 padding calculate
padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
padded_m_splits = [
(m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits
]
# no padding needed
if m_splits == padded_m_splits:
return inp
if torch.is_grad_enabled():
fn = _Fp8Unpadding.apply
......
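A usage sketch of the new align_size argument, assuming Fp8Padding and Fp8Unpadding are importable from transformer_engine.pytorch as the classes above suggest; this is illustrative, not part of the commit:

import torch
import transformer_engine.pytorch as te

num_gemms = 2
pad = te.Fp8Padding(num_gemms, align_size=32)      # MXFP8 needs 32-row alignment
unpad = te.Fp8Unpadding(num_gemms, align_size=32)

m_splits = [33, 17]                                # tokens routed to each expert
inp = torch.randn(sum(m_splits), 1024, device="cuda")

padded_inp, padded_m_splits = pad(inp, m_splits)   # padded_m_splits == [64, 32]
# ... grouped GEMMs run on padded_inp with padded_m_splits ...
restored = unpad(padded_inp, m_splits)             # drops the padding rows again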
......@@ -9,6 +9,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from .base import (
get_multi_stream_cublas_workspace,
TransformerEngineBaseModule,
......@@ -37,7 +38,6 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..tensor.float8_tensor import Float8Tensor
from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.quantized_tensor import (
......@@ -47,7 +47,6 @@ from ..tensor.quantized_tensor import (
restore_from_saved,
)
__all__ = ["GroupedLinear"]
......@@ -85,15 +84,6 @@ class _GroupedLinear(torch.autograd.Function):
biases = weights_and_biases[num_gemms:]
device = inp.device
# TODO Support MXFP8 # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
raise NotImplementedError("GroupedLinear does not yet support MXFP8")
# TODO Support Float8 Current Scaling # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling")
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
raise NotImplementedError("GroupedLinear does not yet support Float8Blockwise scaling")
# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
......@@ -126,7 +116,11 @@ class _GroupedLinear(torch.autograd.Function):
for output_quantizer in output_quantizers:
output_quantizer.set_usage(rowwise=True, columnwise=False)
fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
inputmats = tex.fused_multi_quantize(
inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
)
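The fprop, dgrad, and wgrad paths in this file now consult the recipe's per-GEMM settings when the recipe exposes them, falling back to the module-level _2X_ACC_* defaults otherwise. The lookup condensed into one illustrative helper (not part of the commit):

def split_accumulator_flag(recipe, stage, default):
    # stage is "fprop", "dgrad" or "wgrad"; recipes that do not define an
    # fp8_gemm_<stage> attribute keep the module default.
    cfg = getattr(recipe, f"fp8_gemm_{stage}", None)
    return cfg.use_split_accumulator if cfg is not None else default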
......@@ -167,7 +161,7 @@ class _GroupedLinear(torch.autograd.Function):
m_splits=m_splits,
bias=biases,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
use_split_accumulator=fprop_gemm_use_split_accumulator,
)
if fp8_calibration:
......@@ -182,6 +176,16 @@ class _GroupedLinear(torch.autograd.Function):
ctx.weights_shape_1 = weights[0].shape[1]
# TODO: update after #1638 is merged. # pylint: disable=fixme
if weight_requires_grad:
for inputmat in inputmats:
if isinstance(inputmat, QuantizedTensor):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if inp.requires_grad:
for weight in weights_fp8:
if isinstance(weight, QuantizedTensor):
weight.update_usage(columnwise_usage=True)
tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats,
*weights_fp8,
......@@ -202,6 +206,7 @@ class _GroupedLinear(torch.autograd.Function):
ctx.num_gemms = num_gemms
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
......@@ -247,6 +252,13 @@ class _GroupedLinear(torch.autograd.Function):
grad_biases = [None] * ctx.num_gemms
if ctx.fp8:
if ctx.use_bias:
# unfuse bgrad for now until cast_transpose + dgrad calculation is ready
# for Float8BlockQuantizer.
if ctx.fp8_recipe.float8_block_scaling():
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output[i] = ctx.grad_output_quantizers[i](grad_output_mats[i])
else:
for i in range(ctx.num_gemms):
grad_biases[i], grad_output[i] = tex.bgrad_quantize(
grad_output_mats[i], ctx.grad_output_quantizers[i]
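For the blockwise recipe the fused bgrad+quantize kernel is skipped, so the bias gradient is just a reduction over the token dimension, with the grad-output quantization performed separately. In plain PyTorch terms (illustrative shapes):

import torch

grad_output_mat = torch.randn(64, 4096)   # [tokens, out_features] for one GEMM
grad_bias = grad_output_mat.sum(dim=0)    # [out_features], same as the unfused path above
# The quantizer call ctx.grad_output_quantizers[i](grad_output_mat) then casts the
# gradient to the blockwise FP8 format in a separate step.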
......@@ -269,6 +281,13 @@ class _GroupedLinear(torch.autograd.Function):
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.requires_dgrad:
dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"):
dgrad_gemm_use_split_accumulator = (
recipe.fp8_gemm_dgrad.use_split_accumulator
)
dgrad = torch.empty(
(sum(ctx.m_splits), ctx.weights_shape_1),
dtype=ctx.activation_dtype,
......@@ -285,10 +304,17 @@ class _GroupedLinear(torch.autograd.Function):
layout="NN",
m_splits=ctx.m_splits,
grad=True,
use_split_accumulator=_2X_ACC_DGRAD,
use_split_accumulator=dgrad_gemm_use_split_accumulator,
)
if ctx.weights_requires_grad:
wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_wgrad"):
wgrad_gemm_use_split_accumulator = (
recipe.fp8_gemm_wgrad.use_split_accumulator
)
if ctx.fuse_wgrad_accumulation:
wgrad_list = main_grads
else:
......@@ -308,7 +334,7 @@ class _GroupedLinear(torch.autograd.Function):
m_splits=ctx.m_splits,
use_bias=ctx.use_bias if grad_biases[0] is None else None,
bias=biases,
use_split_accumulator=_2X_ACC_WGRAD,
use_split_accumulator=wgrad_gemm_use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad,
)
for i in range(ctx.num_gemms):
......@@ -374,8 +400,8 @@ class _GroupedLinear(torch.autograd.Function):
None,
None,
None,
None, # is_grad_enabled
None, # is_grad_enabled
None,
None,
*wgrad_list,
*grad_biases,
)
......@@ -425,6 +451,9 @@ class GroupedLinear(TransformerEngineBaseModule):
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
Note: GroupedLinear does not handle TP communication internally; `tp_size` and
`parallel_mode` are only used to determine the shapes of weights and biases.
TP communication should be handled in the dispatch and combine stages of MoE models.
"""
def __init__(
......@@ -467,7 +496,11 @@ class GroupedLinear(TransformerEngineBaseModule):
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}
self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
self._num_fp8_tensors_per_gemm = {
"fwd": 3,
"bwd": 2,
}
if tp_group is None:
self.tp_size = tp_size
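The offsets above switch the FP8 meta-tensor layout from blocks of num_gemms entries per role to interleaved per-GEMM groups: (input, weight, output) in the forward and (grad_output, grad_input) in the backward. The resulting index arithmetic, as an illustrative sketch:

offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
num_per_gemm = {"fwd": 3, "bwd": 2}

def fp8_meta_index(role, gemm_idx, direction):
    # Quantizer index for a given role and GEMM in the interleaved layout.
    return offsets[role] + gemm_idx * num_per_gemm[direction]

# e.g. the weight quantizer of GEMM 2 lives at forward index 1 + 2 * 3 == 7
assert fp8_meta_index("weight", 2, "fwd") == 7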
......@@ -478,6 +511,12 @@ class GroupedLinear(TransformerEngineBaseModule):
self.set_tensor_parallel_group(tp_group)
self.set_nccl_overlap_warning_if_tp()
if self.tp_size > 1 and bias:
raise ValueError(
"GroupedLinear doesn't support bias when TP > 1. "
"Because the TP communication is handled outside of this module."
)
self.parallel_mode = parallel_mode
assert (
self.parallel_mode in GemmParallelModes
......@@ -504,7 +543,7 @@ class GroupedLinear(TransformerEngineBaseModule):
),
init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=self._offsets["weight"] + i,
fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"],
)
# Construct bias parameters if needed
......@@ -529,12 +568,18 @@ class GroupedLinear(TransformerEngineBaseModule):
self.reset_parameters(defer_init=device == "meta")
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.apply_bias:
self.gemm_bias_unfused_add = True
else:
self.gemm_bias_unfused_add = False
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
"""Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
assert not self.tp_size > 1, (
"GroupedLinear doesn't support TP > 1 with Float8 current scaling. "
"Because the TP communication is handled outside of this module."
)
self._customize_quantizers_float8_current_scaling(fwd, recipe)
def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
......@@ -592,7 +637,7 @@ class GroupedLinear(TransformerEngineBaseModule):
produced)
"""
assert not isinstance(
inp, Float8Tensor
inp, QuantizedTensor
), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
......@@ -617,20 +662,27 @@ class GroupedLinear(TransformerEngineBaseModule):
grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
if self.fp8:
input_quantizers = [
self.quantizers["scaling_fwd"][self._offsets["input"] + i]
self.quantizers["scaling_fwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
]
for i in range(self.num_gemms)
]
# TODO: use internal after #1638 is merged. # pylint: disable=fixme
for i in range(self.num_gemms):
input_quantizers[i].internal = True
input_quantizers[i].internal = False
weight_quantizers = [
self.quantizers["scaling_fwd"][self._offsets["weight"] + i]
self.quantizers["scaling_fwd"][
self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
weight_quantizers[i].internal = True
if torch.is_grad_enabled():
grad_output_quantizers = [
self.quantizers["scaling_bwd"][self._offsets["input"] + i]
self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
......@@ -645,7 +697,7 @@ class GroupedLinear(TransformerEngineBaseModule):
args += (
inp,
m_splits,
self.apply_bias and not self.gemm_bias_unfused_add,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
......@@ -665,17 +717,37 @@ class GroupedLinear(TransformerEngineBaseModule):
)
out = linear_fn(*args)
if self.gemm_bias_unfused_add:
out_shape = out.shape
out = torch.cat(
[
o + cast_if_needed(b, self.activation_dtype)
for o, b in zip(
torch.split(out.view(-1, self.out_features), m_splits), bias_tensors
)
]
).view(out_shape)
if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
return out
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
assert (
recipe.float8_current_scaling()
), "current scaling recipe quantizer customization here"
if fwd:
for i in range(self.num_gemms):
# set configs about amax epsilon and power_2_scale
self.quantizers["scaling_fwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
self.quantizers["scaling_fwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon
# also set weight quantizer with same amax_epsilon & power_2_scale
self.quantizers["scaling_fwd"][
self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
self.quantizers["scaling_fwd"][
self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
else:
for i in range(self.num_gemms):
# set grad_output_quantizer with amax epsilon and power_2_scale
self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon