Unverified Commit 4c9626e7 authored by Xin Yao, committed by GitHub

[PyTorch][MoE] Enable New Recipes for Grouped Linear (#1525)

* Enable MXFP8 and Per-Tensor Current Scaling for Grouped Linear
Signed-off-by: Xin Yao <xiny@nvidia.com>

* enable float8blockwise
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update
Signed-off-by: Xin Yao <xiny@nvidia.com>

* remove grouped linear parallel mode test
Signed-off-by: Xin Yao <xiny@nvidia.com>

* update test
Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments
Signed-off-by: Xin Yao <xiny@nvidia.com>

* internal=False for now
Signed-off-by: Xin Yao <xiny@nvidia.com>

* remove unused import
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

parent 6117b20c
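
For context on what the enabled recipes look like from the user side, here is a minimal, illustrative sketch of exercising GroupedLinear under one of the newly supported recipes via fp8_autocast. GroupedLinear, fp8_autocast, and the recipe classes are existing Transformer Engine APIs; the shapes, splits, and constructor defaults below are assumptions made only for the example.

# Illustrative sketch, not part of this commit: run GroupedLinear under a new recipe.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.GroupedLinear(num_gemms=4, in_features=1024, out_features=1024, device="cuda")
inp = torch.randn(512, 1024, device="cuda")
m_splits = [128, 128, 128, 128]  # per-expert token counts; multiples of 32 suit MXFP8

with te.fp8_autocast(enabled=True, fp8_recipe=recipe.MXFP8BlockScaling()):
    out = layer(inp, m_splits)

The same call should work with recipe.Float8CurrentScaling() or a blockwise-scaling recipe once this change is in; only the fp8_recipe argument changes.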
@@ -1470,8 +1470,7 @@ def _test_grouped_linear_accuracy(
     if num_gemms > 1:
         split_size = 1
         if fp8:
-            if recipe.delayed():
-                split_size = 16
+            split_size = 16
             if recipe.mxfp8():
                 split_size = 128
         m = config.seq_len // split_size
@@ -1509,12 +1508,11 @@ def _test_grouped_linear_accuracy(
     return outputs


-@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("dtype", param_types, ids=str)
 @pytest.mark.parametrize("num_gemms", [3, 6])
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", ["126m"])
-@pytest.mark.parametrize("fp8", all_boolean)
-@pytest.mark.parametrize("recipe", fp8_recipes)
+@pytest.mark.parametrize("recipe", fp8_recipes + [None])
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
 def test_grouped_linear_accuracy(
@@ -1522,22 +1520,18 @@ def test_grouped_linear_accuracy(
     num_gemms,
     bs,
     model,
-    fp8,
     recipe,
     fp8_model_params,
     fuse_wgrad_accumulation,
     parallel_mode=None,
 ):
+    fp8 = recipe is not None
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
-        pytest.skip("MXFP8 unsupported for grouped linear.")
-    if fp8 and recipe.float8_current_scaling():
-        pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
-    if recipe.float8_block_scaling():
-        pytest.skip("Grouped linear for FP8 blockwise unsupported.")
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.seq_len % 16 != 0 and fp8:
@@ -1591,24 +1585,7 @@ def test_grouped_linear_accuracy(
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


-@pytest.mark.parametrize("parallel_mode", ["column", "row"])
-@pytest.mark.parametrize("recipe", fp8_recipes)
-def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe):
-    """Split the tests to save CI time"""
-    test_grouped_linear_accuracy(
-        dtype=torch.float32,
-        num_gemms=6,
-        bs=2,
-        model="126m",
-        fp8=True,
-        recipe=recipe,
-        fp8_model_params=True,
-        parallel_mode=parallel_mode,
-        fuse_wgrad_accumulation=True,
-    )
-
-
-@pytest.mark.parametrize("recipe", fp8_recipes)
+@pytest.mark.parametrize("recipe", fp8_recipes + [None])
 def test_grouped_linear_accuracy_single_gemm(recipe):
     """Split the tests to save CI time"""
     test_grouped_linear_accuracy(
@@ -1616,7 +1593,6 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
         num_gemms=1,
         bs=2,
         model="126m",
-        fp8=True,
         recipe=recipe,
         fp8_model_params=True,
         fuse_wgrad_accumulation=True,
@@ -1626,9 +1602,12 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
 def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):

     def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
-        """Padding tensor shapes to multiples of 16."""
+        align_size = 16
+        if recipe.mxfp8():
+            align_size = 32
         padded_tokens_per_expert = [
-            (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
+            (num_tokens + align_size - 1) // align_size * align_size
+            for num_tokens in tokens_per_expert
         ]
         hidden_states = torch.split(hidden_states, tokens_per_expert)
         padded_hidden_states = []
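
The helper above rounds each expert's token count up to the recipe's alignment, 16 by default and 32 for MXFP8. A small worked example of that round-up arithmetic, with made-up token counts:

# (n + align - 1) // align * align rounds n up to the next multiple of align.
tokens_per_expert = [5, 16, 23]   # hypothetical per-expert token counts
for align_size in (16, 32):       # 16 for most FP8 recipes, 32 for MXFP8
    padded = [(n + align_size - 1) // align_size * align_size for n in tokens_per_expert]
    print(align_size, padded)     # 16 -> [16, 16, 32]; 32 -> [32, 32, 32]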
@@ -1729,12 +1708,8 @@ def test_padding_grouped_linear_accuracy(
         pytest.skip(reason_for_no_fp8)
     if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
-        pytest.skip("MXFP8 unsupported for grouped linear.")
-    if fp8 and recipe.float8_current_scaling():
-        pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
-    if recipe.float8_block_scaling():
-        pytest.skip("Float8 block scaling unsupported for grouped linear.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.seq_len % 16 != 0 and fp8:
......
@@ -553,14 +553,10 @@ def test_sanity_grouped_linear(
     if fp8_recipe is not None:
         if not fp8_available:
             pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
         if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
             pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8():
-            pytest.skip("Grouped linear does not support MXFP8")
-        if fp8_recipe.float8_current_scaling():
-            pytest.skip("Grouped linear does not support FP8 current scaling")
-        if fp8_recipe.float8_block_scaling():
-            pytest.skip("Grouped linear does not support FP8 block scaling")
         if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
......
@@ -9,10 +9,9 @@ import os
 import torch
 import transformer_engine_torch as tex

 from ..constants import TE_DType
-from ..utils import assert_dim_for_fp8_exec, get_sm_count
+from ..utils import get_sm_count
 from ..tensor.quantized_tensor import Quantizer
-from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
@@ -174,14 +173,6 @@ def general_grouped_gemm(
     transa = layout[0] == "T"
     transb = layout[1] == "T"

-    # assert [a.is_contiguous() for a in A]
-    # assert [b.is_contiguous() for b in B]
-
-    if isinstance(A[0], Float8TensorBase):
-        for a, b in zip(A, B):
-            assert_dim_for_fp8_exec(a._data)
-            assert_dim_for_fp8_exec(b._data)
-
     empty_tensor = _empty_tensor()
     empty_tensors = [empty_tensor] * num_gemms
@@ -208,6 +199,8 @@ def general_grouped_gemm(
         for o in out
     ]  # this should differ with respect to single output

+    # TODO: Move the swizzle to the C++ side.  # pylint: disable=fixme
+    original_scale_inverses_list = [swizzle_inputs(A[i], B[i], layout) for i in range(num_gemms)]
     bias = tex.te_general_grouped_gemm(
         A,
         transa,
@@ -227,5 +220,7 @@ def general_grouped_gemm(
         use_split_accumulator,
         sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
     )
+    for i in range(num_gemms):
+        reset_swizzled_inputs(A[i], B[i], original_scale_inverses_list[i])

     return out, bias, gelu_input
@@ -101,18 +101,22 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
     bool use_split_accumulator, int math_sm_count);

+namespace transformer_engine::pytorch {
+
 /***************************************************************************************************
  * Transpose
  **************************************************************************************************/

-std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
-                                             std::optional<std::vector<py::handle>> output_list,
+std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
+                                             std::optional<std::vector<py::object>> output_list,
                                              std::vector<py::handle> quantizer_list,
                                              transformer_engine::DType otype);

 at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
                          std::optional<at::Tensor> output = std::nullopt);

+}  // namespace transformer_engine::pytorch
+
 namespace transformer_engine::pytorch {

 /***************************************************************************************************
......
@@ -196,12 +196,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
         py::arg("zero_centered_gamma"));
   m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
-  m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose",
-        py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype"));
+  m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
+        "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
+        py::arg("quantizer_list"), py::arg("otype"));
   m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
-  m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
-        py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
+  m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
+        py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
+        py::call_guard<py::gil_scoped_release>());
   m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
         py::call_guard<py::gil_scoped_release>());
   m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
......
@@ -109,6 +109,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
   }
   const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
   opts = opts.dtype(torch::kFloat32);
+  // TODO: Replace with an empty tensor.
   at::Tensor scale_inv = at::reciprocal(scale);
   py::object ret;
   if (internal) {
......
@@ -6,27 +6,38 @@
 #include <optional>

+#include "ATen/core/TensorBody.h"
 #include "extensions.h"
+#include "pybind.h"

-std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
-                                             std::optional<std::vector<py::handle>> output_list,
+namespace transformer_engine::pytorch {
+
+std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
+                                             std::optional<std::vector<py::object>> output_list,
                                              std::vector<py::handle> quantizer_list,
                                              transformer_engine::DType otype) {
-  using namespace transformer_engine::pytorch;
+  init_extension();
   std::vector<NVTETensor> nvte_tensor_input_list;
   std::vector<NVTETensor> nvte_tensor_output_list;
   std::vector<py::object> py_output_objects_list;
   std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
-  auto none = py::none();
+  if (output_list.has_value()) {
+    py_output_objects_list = output_list.value();
+  }
+
+  // Choose implementation
+  // Note: Currently only have fused kernel for FP8 cast-transpose
+  bool with_fused_kernel = true;

   // create TE tensors from input
   for (size_t i = 0; i < input_list.size(); i++) {
-    auto input_tensor = makeTransformerEngineTensor(input_list[i], none);
+    auto input_tensor = makeTransformerEngineTensor(input_list[i]);
     const NVTEShape input_shape = input_tensor.shape();

     transformer_engine::TensorWrapper output_tensor;
+    if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) {
+      with_fused_kernel = false;
+    }
     if (output_list == std::nullopt) {
       std::unique_ptr<Quantizer> quantizer = convert_quantizer(quantizer_list[i]);
       std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
@@ -48,16 +59,8 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
   NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(),
              "Number of input and output tensors must match");

-  // Choose implementation
-  // Note: Currently only have fused kernel for FP8 cast-transpose
-  bool with_fused_kernel = true;
   for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
-    const auto& tensor = nvte_tensor_output_list[i];
-    if (nvte_tensor_scaling_mode(tensor) != NVTE_DELAYED_TENSOR_SCALING) {
-      with_fused_kernel = false;
-      break;
-    }
-    if (nvte_tensor_columnwise_data(tensor) == nullptr) {
+    if (nvte_tensor_columnwise_data(nvte_tensor_output_list[i]) == nullptr) {
       with_fused_kernel = false;
       break;
     }
@@ -68,10 +71,8 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
     nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
                               nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
   } else {
-    for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
-      // TODO: switch to nvte_quantize_v2 with advanced numerical options
-      nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i],
-                    at::cuda::getCurrentCUDAStream());
+    for (size_t i = 0; i < py_output_objects_list.size(); i++) {
+      quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt);
     }
   }
   return py_output_objects_list;
@@ -79,7 +80,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,

 at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
                          std::optional<at::Tensor> output) {
-  using namespace transformer_engine::pytorch;
+  init_extension();
   const auto dim = input.dim();
   NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose.");
@@ -106,3 +107,5 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
   return out;
 }
+
+}  // namespace transformer_engine::pytorch
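
Restated at a high level, fused_multi_quantize now uses the fused multi cast-transpose kernel only when every quantizer is a plain Float8 (delayed-scaling) quantizer and every output still carries columnwise data; otherwise it falls back to per-tensor quantization, which is what lets MXFP8, current-scaling, and blockwise outputs flow through this path. A rough Python-level restatement of that dispatch follows; the helper names are hypothetical placeholders, not the extension's API.

# Pseudocode of the dispatch above; is_float8_quantizer, has_columnwise_data,
# fused_multi_cast_transpose, and quantize_one are made-up placeholder names.
def fused_multi_quantize_dispatch(inputs, outputs, quantizers):
    use_fused = all(is_float8_quantizer(q) for q in quantizers) and all(
        has_columnwise_data(o) for o in outputs
    )
    if use_fused:
        fused_multi_cast_transpose(inputs, outputs)  # one fused kernel launch
    else:
        for inp, quantizer, out in zip(inputs, quantizers, outputs):
            quantize_one(inp, quantizer, out)  # per-tensor fallback path
    return outputs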
@@ -4,12 +4,13 @@
 """FP8 Padding API"""

-from typing import Union, List
+from typing import List, Optional, Tuple

 import torch

 import transformer_engine_torch as tex

+from ..fp8 import FP8GlobalStateManager
 from ..jit import no_torch_dynamo
@@ -74,22 +75,30 @@ class Fp8Padding(torch.nn.Module):
     ----------
     num_gemms: int
                number of GEMMs to be performed simutaneously.
+    align_size: int, optional
+               the alignment size for the input tensor. If not provided, the alignment size will
+               be determined by the FP8 recipe, 32 for MXFP8 and 16 for others.
     """

     def __init__(
         self,
-        num_gemms,
+        num_gemms: int,
+        align_size: Optional[int] = None,
     ) -> None:
         super().__init__()

         self.num_gemms = num_gemms
+        if align_size is None:
+            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
+        else:
+            self.align_size = align_size

     @no_torch_dynamo()
     def forward(
         self,
         inp: torch.Tensor,
         m_splits: List[int],
-    ) -> Union[torch.Tensor, List[int]]:
+    ) -> Tuple[torch.Tensor, List[int]]:
         """
         Apply the padding to the input.
@@ -104,7 +113,12 @@ class Fp8Padding(torch.nn.Module):
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

         # FP8 padding calculate
-        padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
+        padded_m_splits = [
+            (m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits
+        ]
+        # no padding needed
+        if m_splits == padded_m_splits:
+            return inp, m_splits

         if torch.is_grad_enabled():
             fn = _Fp8Padding.apply
......
@@ -4,12 +4,13 @@
 """FP8 Padding API"""

-from typing import List
+from typing import List, Optional

 import torch

 import transformer_engine_torch as tex

+from ..fp8 import FP8GlobalStateManager
 from ..jit import no_torch_dynamo
@@ -70,15 +71,23 @@ class Fp8Unpadding(torch.nn.Module):
     ----------
     num_gemms: int
                number of GEMMs to be performed simutaneously.
+    align_size: int, optional
+               the alignment size for the input tensor. If not provided, the alignment size will
+               be determined by the FP8 recipe, 32 for MXFP8 and 16 for others.
     """

     def __init__(
         self,
-        num_gemms,
+        num_gemms: int,
+        align_size: Optional[int] = None,
     ) -> None:
         super().__init__()

         self.num_gemms = num_gemms
+        if align_size is None:
+            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
+        else:
+            self.align_size = align_size

     @no_torch_dynamo()
     def forward(
@@ -100,7 +109,12 @@ class Fp8Unpadding(torch.nn.Module):
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

         # FP8 padding calculate
-        padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
+        padded_m_splits = [
+            (m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits
+        ]
+        # no padding needed
+        if m_splits == padded_m_splits:
+            return inp

         if torch.is_grad_enabled():
             fn = _Fp8Unpadding.apply
......
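
Taken together, Fp8Padding and Fp8Unpadding above now derive their alignment from the active FP8 recipe (or an explicit align_size) and return early when the splits are already aligned. A minimal usage sketch; the module names come from this diff, while the tensor sizes, splits, and device handling are illustrative assumptions.

# Illustrative: pad variable-length expert inputs before the grouped GEMMs, then unpad.
import torch
from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding

num_gemms = 3
pad = Fp8Padding(num_gemms, align_size=32)      # e.g. force MXFP8-style alignment
unpad = Fp8Unpadding(num_gemms, align_size=32)

inp = torch.randn(5 + 16 + 23, 256, device="cuda")
m_splits = [5, 16, 23]                          # unaligned per-expert token counts

padded_inp, padded_m_splits = pad(inp, m_splits)   # splits become [32, 32, 32]
# ... run GroupedLinear / grouped GEMMs on padded_inp with padded_m_splits ...
restored = unpad(padded_inp, m_splits)             # restores the original row count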
@@ -9,6 +9,7 @@ import torch

 import transformer_engine_torch as tex

+from transformer_engine.common.recipe import Recipe
 from .base import (
     get_multi_stream_cublas_workspace,
     TransformerEngineBaseModule,
@@ -37,7 +38,6 @@ from ..cpp_extensions import (
 from ..constants import GemmParallelModes, dist_group_type, TE_DType
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
-from ..tensor.float8_tensor import Float8Tensor
 from ..cpu_offload import is_cpu_offload_enabled

 from ..tensor.quantized_tensor import (
@@ -47,7 +47,6 @@ from ..tensor.quantized_tensor import (
     restore_from_saved,
 )

-
 __all__ = ["GroupedLinear"]
@@ -85,15 +84,6 @@ class _GroupedLinear(torch.autograd.Function):
         biases = weights_and_biases[num_gemms:]
         device = inp.device

-        # TODO Support MXFP8  # pylint: disable=fixme
-        if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
-            raise NotImplementedError("GroupedLinear does not yet support MXFP8")
-        # TODO Support Float8 Current Scaling  # pylint: disable=fixme
-        if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
-            raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling")
-        if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
-            raise NotImplementedError("GroupedLinear does not yet support Float8Blockwise scaling")
-
         # Make sure input dimensions are compatible
         in_features = weights[0].shape[-1]
         assert inp.shape[-1] == in_features, "GEMM not possible"
@@ -126,7 +116,11 @@ class _GroupedLinear(torch.autograd.Function):
             for output_quantizer in output_quantizers:
                 output_quantizer.set_usage(rowwise=True, columnwise=False)

+        fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
         if fp8:
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+            if hasattr(recipe, "fp8_gemm_fprop"):
+                fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
             inputmats = tex.fused_multi_quantize(
                 inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
             )
@@ -167,7 +161,7 @@ class _GroupedLinear(torch.autograd.Function):
                 m_splits=m_splits,
                 bias=biases,
                 use_bias=use_bias,
-                use_split_accumulator=_2X_ACC_FPROP,
+                use_split_accumulator=fprop_gemm_use_split_accumulator,
             )

         if fp8_calibration:
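
The three GEMMs (fprop here, dgrad and wgrad in the backward hunks further down) now read use_split_accumulator from the active recipe when it exposes per-GEMM settings, keeping the module defaults otherwise. A minimal sketch of that resolution logic, assuming a recipe object shaped like the ones used in the diff (fp8_gemm_fprop carrying a use_split_accumulator field):

# Sketch only: mirrors the pattern added above for choosing the split-accumulator flag.
def resolve_split_accumulator(recipe, fp8: bool, default: bool) -> bool:
    use_split_accumulator = default  # module default, e.g. _2X_ACC_FPROP in the diff
    if fp8 and hasattr(recipe, "fp8_gemm_fprop"):
        # Recipes that carry per-GEMM knobs (e.g. current scaling) override the default.
        use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
    return use_split_accumulator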
@@ -182,6 +176,16 @@ class _GroupedLinear(torch.autograd.Function):

             ctx.weights_shape_1 = weights[0].shape[1]

+            # TODO: update after #1638 is merged.  # pylint: disable=fixme
+            if weight_requires_grad:
+                for inputmat in inputmats:
+                    if isinstance(inputmat, QuantizedTensor):
+                        inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
+            if inp.requires_grad:
+                for weight in weights_fp8:
+                    if isinstance(weight, QuantizedTensor):
+                        weight.update_usage(columnwise_usage=True)
+
             tensors_to_save, tensor_objects = prepare_for_saving(
                 *inputmats,
                 *weights_fp8,
@@ -202,6 +206,7 @@ class _GroupedLinear(torch.autograd.Function):
             ctx.num_gemms = num_gemms
             ctx.activation_dtype = activation_dtype
             ctx.fp8 = fp8
+            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
             ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
             ctx.cpu_offloading = cpu_offloading
             ctx.is_first_microbatch = is_first_microbatch
@@ -247,10 +252,17 @@ class _GroupedLinear(torch.autograd.Function):
             grad_biases = [None] * ctx.num_gemms
             if ctx.fp8:
                 if ctx.use_bias:
-                    for i in range(ctx.num_gemms):
-                        grad_biases[i], grad_output[i] = tex.bgrad_quantize(
-                            grad_output_mats[i], ctx.grad_output_quantizers[i]
-                        )
+                    # unfuse bgrad for now until cast_transpose + dgrad calculation is ready
+                    # for Float8BlockQuantizer.
+                    if ctx.fp8_recipe.float8_block_scaling():
+                        for i in range(ctx.num_gemms):
+                            grad_biases[i] = grad_output_mats[i].sum(dim=0)
+                            grad_output[i] = ctx.grad_output_quantizers[i](grad_output_mats[i])
+                    else:
+                        for i in range(ctx.num_gemms):
+                            grad_biases[i], grad_output[i] = tex.bgrad_quantize(
+                                grad_output_mats[i], ctx.grad_output_quantizers[i]
+                            )
                 else:
                     grad_output = tex.fused_multi_quantize(
                         grad_output_mats,
@@ -269,6 +281,13 @@ class _GroupedLinear(torch.autograd.Function):
             accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

             if ctx.requires_dgrad:
+                dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
+                if ctx.fp8:
+                    recipe = ctx.fp8_recipe
+                    if hasattr(recipe, "fp8_gemm_dgrad"):
+                        dgrad_gemm_use_split_accumulator = (
+                            recipe.fp8_gemm_dgrad.use_split_accumulator
+                        )
                 dgrad = torch.empty(
                     (sum(ctx.m_splits), ctx.weights_shape_1),
                     dtype=ctx.activation_dtype,
@@ -285,10 +304,17 @@ class _GroupedLinear(torch.autograd.Function):
                     layout="NN",
                     m_splits=ctx.m_splits,
                     grad=True,
-                    use_split_accumulator=_2X_ACC_DGRAD,
+                    use_split_accumulator=dgrad_gemm_use_split_accumulator,
                 )

             if ctx.weights_requires_grad:
+                wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
+                if ctx.fp8:
+                    recipe = ctx.fp8_recipe
+                    if hasattr(recipe, "fp8_gemm_wgrad"):
+                        wgrad_gemm_use_split_accumulator = (
+                            recipe.fp8_gemm_wgrad.use_split_accumulator
+                        )
                 if ctx.fuse_wgrad_accumulation:
                     wgrad_list = main_grads
                 else:
@@ -308,7 +334,7 @@ class _GroupedLinear(torch.autograd.Function):
                     m_splits=ctx.m_splits,
                     use_bias=ctx.use_bias if grad_biases[0] is None else None,
                     bias=biases,
-                    use_split_accumulator=_2X_ACC_WGRAD,
+                    use_split_accumulator=wgrad_gemm_use_split_accumulator,
                     accumulate=accumulate_wgrad_into_param_main_grad,
                 )

             for i in range(ctx.num_gemms):
@@ -374,8 +400,8 @@ class _GroupedLinear(torch.autograd.Function):
             None,
             None,
             None,
-            None,  # is_grad_enabled
-            None,  # is_grad_enabled
+            None,
+            None,
             *wgrad_list,
             *grad_biases,
         )
@@ -425,6 +451,9 @@ class GroupedLinear(TransformerEngineBaseModule):
                       the model is trained with lower precision and the original FP32 parameters
                       would not fit in GPU memory.

+    Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and
+    `parallel_mode` are used to determine the shapes of weights and biases.
+    The TP communication should be handled in the dispatch and combine stages of MoE models.
     """

     def __init__(
@@ -467,7 +496,11 @@ class GroupedLinear(TransformerEngineBaseModule):
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name

-        self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}
+        self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
+        self._num_fp8_tensors_per_gemm = {
+            "fwd": 3,
+            "bwd": 2,
+        }

         if tp_group is None:
             self.tp_size = tp_size
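
The _offsets change above switches the fp8_meta layout from one block per tensor role to an interleaved per-GEMM layout: forward meta tensors are ordered input, weight, output for GEMM 0, then GEMM 1, and so on, so an index is offset + i * tensors_per_gemm. A small self-contained illustration of that arithmetic (names mirror the diff, values are only for demonstration):

# Per-GEMM interleaved layout: fwd = [inp0, w0, out0, inp1, w1, out1, ...]
offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
num_fp8_tensors_per_gemm = {"fwd": 3, "bwd": 2}

def fwd_meta_index(role: str, gemm_idx: int) -> int:
    return offsets[role] + gemm_idx * num_fp8_tensors_per_gemm["fwd"]

assert fwd_meta_index("input", 0) == 0
assert fwd_meta_index("weight", 2) == 7   # third GEMM's weight quantizer slot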
@@ -478,6 +511,12 @@ class GroupedLinear(TransformerEngineBaseModule):
             self.set_tensor_parallel_group(tp_group)
         self.set_nccl_overlap_warning_if_tp()

+        if self.tp_size > 1 and bias:
+            raise ValueError(
+                "GroupedLinear doesn't support bias when TP > 1. "
+                "Because the TP communication is handled outside of this module."
+            )
+
         self.parallel_mode = parallel_mode
         assert (
             self.parallel_mode in GemmParallelModes
@@ -504,7 +543,7 @@ class GroupedLinear(TransformerEngineBaseModule):
                 ),
                 init_fn=init_method,
                 get_rng_state_tracker=get_rng_state_tracker,
-                fp8_meta_index=self._offsets["weight"] + i,
+                fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"],
             )

         # Construct bias parameters if needed
@@ -529,12 +568,18 @@ class GroupedLinear(TransformerEngineBaseModule):
         self.reset_parameters(defer_init=device == "meta")

-        # For RPL, bias has to be added after TP collectives
-        # So it cannot be fused with the GEMM
-        if self.parallel_mode == "row" and self.apply_bias:
-            self.gemm_bias_unfused_add = True
-        else:
-            self.gemm_bias_unfused_add = False
+    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
+        """Init scales and amaxes for fwd | bwd."""
+        super().set_meta_tensor(fwd, recipe)
+
+        # customize quantizers based on each recipe & layer configs
+        recipe = FP8GlobalStateManager.get_fp8_recipe()
+        if recipe.float8_current_scaling():
+            assert not self.tp_size > 1, (
+                "GroupedLinear doesn't support TP > 1 with Float8 current scaling. "
+                "Because the TP communication is handled outside of this module."
+            )
+            self._customize_quantizers_float8_current_scaling(fwd, recipe)

     def reset_parameters(self, defer_init=False):
         super().reset_parameters(defer_init=defer_init)
@@ -592,7 +637,7 @@ class GroupedLinear(TransformerEngineBaseModule):
                  produced)
         """
         assert not isinstance(
-            inp, Float8Tensor
+            inp, QuantizedTensor
         ), "GroupedLinear doesn't support input tensor in FP8."
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
@@ -617,20 +662,27 @@ class GroupedLinear(TransformerEngineBaseModule):
         grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
         if self.fp8:
             input_quantizers = [
-                self.quantizers["scaling_fwd"][self._offsets["input"] + i]
+                self.quantizers["scaling_fwd"][
+                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
+                ]
                 for i in range(self.num_gemms)
             ]
+            # TODO: use internal after #1638 is merged.  # pylint: disable=fixme
             for i in range(self.num_gemms):
-                input_quantizers[i].internal = True
+                input_quantizers[i].internal = False
             weight_quantizers = [
-                self.quantizers["scaling_fwd"][self._offsets["weight"] + i]
+                self.quantizers["scaling_fwd"][
+                    self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
+                ]
                 for i in range(self.num_gemms)
             ]
             for i in range(self.num_gemms):
                 weight_quantizers[i].internal = True
             if torch.is_grad_enabled():
                 grad_output_quantizers = [
-                    self.quantizers["scaling_bwd"][self._offsets["input"] + i]
+                    self.quantizers["scaling_bwd"][
+                        self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
+                    ]
                     for i in range(self.num_gemms)
                 ]
                 for i in range(self.num_gemms):
@@ -645,7 +697,7 @@ class GroupedLinear(TransformerEngineBaseModule):
             args += (
                 inp,
                 m_splits,
-                self.apply_bias and not self.gemm_bias_unfused_add,
+                self.apply_bias,
                 is_first_microbatch,
                 self.fp8,
                 self.fp8_calibration,
@@ -665,17 +717,37 @@ class GroupedLinear(TransformerEngineBaseModule):
             )
         out = linear_fn(*args)

-        if self.gemm_bias_unfused_add:
-            out_shape = out.shape
-            out = torch.cat(
-                [
-                    o + cast_if_needed(b, self.activation_dtype)
-                    for o, b in zip(
-                        torch.split(out.view(-1, self.out_features), m_splits), bias_tensors
-                    )
-                ]
-            ).view(out_shape)
-
         if self.return_bias:
             return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
         return out
+
+    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on current scaling recipe + linear."""
+        assert (
+            recipe.float8_current_scaling()
+        ), "current scaling recipe quantizer customization here"
+        if fwd:
+            for i in range(self.num_gemms):
+                # set configs about amax epsilon and power_2_scale
+                self.quantizers["scaling_fwd"][
+                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
+                ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+                self.quantizers["scaling_fwd"][
+                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
+                ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon
+                # also set weight quantizer with same amax_epsilon & power_2_scale
+                self.quantizers["scaling_fwd"][
+                    self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
+                ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
+                self.quantizers["scaling_fwd"][
+                    self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
+                ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
+        else:
+            for i in range(self.num_gemms):
+                # set grad_output_quantizer with amax epsilon and power_2_scale
+                self.quantizers["scaling_bwd"][
+                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
+                ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
+                self.quantizers["scaling_bwd"][
+                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
+                ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon