Unverified commit 8b326059, authored by Xin Yao, committed by GitHub

[PyTorch] Reduce the CPU overheads of `GroupedLinear` (#1072)



* use fused_multi_cast_transpose
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix handling of empty input tensors
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* allocate output tensors in C++
Signed-off-by: Xin Yao <xiny@nvidia.com>

* simplify code
Signed-off-by: Xin Yao <xiny@nvidia.com>

* avoid cudaGetDriverEntryPoint
Signed-off-by: Xin Yao <xiny@nvidia.com>

* reduce torch.Tensor() calls
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update test
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent fa4b866d
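At its core, the change collapses a per-GEMM loop of cast+transpose extension calls into a single fused call whose output tensors are allocated on the C++ side. A condensed before/after sketch of the calling pattern, not part of the diff itself; the names (`fp8_cast_transpose_fused`, `fp8_multi_cast_transpose_fused`, `_GEMM_INPUT`, `inputmats_no_fp8`, `fp8_meta`) all come from the hunks below:

    # Before: one Python -> C++ round trip per GEMM input.
    inputmats, inputmats_t = [], []
    for i in range(num_gemms):
        mat, mat_t = fp8_cast_transpose_fused(
            inputmats_no_fp8[i], fp8_meta["scaling_fwd"], _GEMM_INPUT + i, fp8_dtype_forward
        )
        inputmats.append(mat)
        inputmats_t.append(mat_t)

    # After: one fused call for all tensors; outputs are allocated in C++.
    indices = list(range(_GEMM_INPUT, _GEMM_INPUT + num_gemms))
    inputmats, inputmats_t = fp8_multi_cast_transpose_fused(
        inputmats_no_fp8, fp8_meta["scaling_fwd"], indices, indices, indices, fp8_dtype_forward
    )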
@@ -1228,7 +1228,8 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
inp_hidden_states.retain_grad()
m = config.seq_len // 16
- dist = torch.sort(torch.randint(0, m, (num_gemms - 1,))).values.tolist()
+ dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
+ dist.append(dist[-1])  # Manually add a zero
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * 16
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
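To see why appending `dist[-1]` produces a zero-sized split, here is the same arithmetic with a made-up draw (`dist = [3, 7]` is hypothetical; only the formula comes from the test above):

    import torch

    m, num_gemms = 10, 4
    dist = [3, 7]                 # pretend torch.randint drew these cut points
    dist.append(dist[-1])         # duplicate the last cut point -> one zero split
    m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
    m_splits = m_splits * 16      # [48, 64, 0, 48]; sums to m * 16 == seq_len
    assert m_splits.sum() == m * 16 and len(m_splits) == num_gemms

The zero-sized split exercises the empty-input path fixed in this PR.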
......
@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""Python interface for GEMM extensions"""
+ import functools
from typing import Optional, Tuple, Union, List
import torch
import transformer_engine_torch as tex
@@ -13,6 +14,12 @@ from ..utils import assert_dim_for_fp8_exec
__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"]
+ @functools.lru_cache(maxsize=None)
+ def _empty_tensor() -> torch.Tensor:
+     """Get tensor with no entries and no data"""
+     return torch.Tensor()
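This helper is the "reduce torch.Tensor() calls" item from the commit message: `functools.lru_cache` memoizes the zero-argument call, so the empty placeholder is constructed once and shared. Sharing one object is safe here because it only ever stands in for "no tensor". A minimal sketch of the effect:

    t1 = _empty_tensor()
    t2 = _empty_tensor()
    assert t1 is t2          # same cached object, no repeated construction
    assert t1.numel() == 0   # still an empty placeholder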
def fp8_gemm(
A: torch.Tensor,
A_scale_inv: torch.Tensor,
@@ -39,7 +46,7 @@ def fp8_gemm(
) -> torch.Tensor:
"""TN layout GEMM with fp8 inputs."""
- empty_tensor = torch.Tensor()
+ empty_tensor = _empty_tensor()
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_index is not None
assert_dim_for_fp8_exec(A)
@@ -195,7 +202,7 @@ def gemm(
assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
transa = layout[0] == "T"
transb = layout[1] == "T"
- empty_tensor = torch.Tensor()
+ empty_tensor = _empty_tensor()
fp8_index = -1 # dummy index
if out is None:
@@ -313,8 +320,8 @@ def grouped_gemm(
transa = layout[0] == "T"
transb = layout[1] == "T"
num_gemms = len(A)
- empty_tensor = torch.Tensor()
- empty_tensors = [torch.Tensor()] * num_gemms
+ empty_tensor = _empty_tensor()
+ empty_tensors = [empty_tensor] * num_gemms
if gelu and not grad:
gelu_input = [
@@ -401,8 +408,8 @@ def fp8_grouped_gemm(
"""
num_gemms = len(A)
- empty_tensor = torch.Tensor()
- empty_tensors = [torch.Tensor()] * num_gemms
+ empty_tensor = _empty_tensor()
+ empty_tensors = [empty_tensor] * num_gemms
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_offset is not None
for a, b in zip(A, B):
......
@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Python interface for transpose extensions"""
- from typing import Optional, Tuple, Union
+ from typing import List, Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
from ..constants import TE_DType
@@ -13,6 +13,7 @@ __all__ = [
"fp8_cast_transpose_fused",
"fp8_cast_transpose_bgrad_fused",
"fp8_cast_transpose_bgrad_dgelu_fused",
"fp8_multi_cast_transpose_fused",
"fp8_transpose_bgrad_fused",
]
@@ -118,3 +119,25 @@ def fp8_cast_transpose_bgrad_dgelu_fused(
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
)
+ def fp8_multi_cast_transpose_fused(
+     input_list: List[torch.Tensor],
+     fp8_meta_tensor: tex.FP8TensorMeta,
+     scale_indices: List[int],
+     amax_indices: List[int],
+     scale_inv_indices: List[int],
+     otype: tex.DType,
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+     """Cast + Transpose with FP8 output"""
+     return tex.fused_multi_cast_transpose_alloc(
+         input_list,
+         fp8_meta_tensor.scale,
+         fp8_meta_tensor.amax_history,
+         fp8_meta_tensor.scale_inv,
+         scale_indices,
+         amax_indices,
+         scale_inv_indices,
+         otype,
+     )
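A hypothetical usage sketch of the new wrapper (the meta object and shapes are assumed, not taken from the diff): casting three bf16 matrices to FP8 and getting their transposes in one extension call.

    inputs = [torch.randn(256, 512, dtype=torch.bfloat16, device="cuda") for _ in range(3)]
    indices = [0, 1, 2]  # offsets into the meta's scale/amax/scale_inv arrays
    casts, transposes = fp8_multi_cast_transpose_fused(
        inputs,
        fp8_meta["scaling_fwd"],    # a tex.FP8TensorMeta, assumed set up elsewhere
        indices,                    # scale_indices
        indices,                    # amax_indices
        indices,                    # scale_inv_indices
        tex.DType.kFloat8E4M3,
    )
    # casts[i] matches inputs[i]'s shape; transposes[i] holds the transposed FP8 copy.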
@@ -180,6 +180,11 @@ void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
std::vector<at::Tensor> scale_inv_output_list,
transformer_engine::DType otype);
+ std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> fused_multi_cast_transpose_alloc(
+     std::vector<at::Tensor> input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
+     std::vector<int> scale_indices, std::vector<int> amax_indices,
+     std::vector<int> scale_inv_indices, transformer_engine::DType otype);
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype);
void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype);
......
@@ -84,6 +84,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose", py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc,
"Fused Multi-tensor Cast + Transpose with allocating output tensors",
py::call_guard<py::gil_scoped_release>());
m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard<py::gil_scoped_release>());
m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8",
py::call_guard<py::gil_scoped_release>());
......
@@ -75,7 +75,7 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output, at::T
// Return immediately if tensors are empty
if (M == 0 || N == 0) {
- return {grad_bias, grad_output_cast, grad_output_transpose};
+ return {grad_bias.zero_(), grad_output_cast, grad_output_transpose};
}
// Get pointers for FP8 scale, amax, scale-inverse
@@ -196,22 +196,21 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
return {grad_bias, dgelu, dgelu_transpose};
}
- void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
-                                 std::vector<at::Tensor> scale_list,
-                                 std::vector<at::Tensor> cast_output_list,
-                                 std::vector<at::Tensor> transposed_output_list,
-                                 std::vector<at::Tensor> amax_list,
-                                 std::vector<at::Tensor> scale_inv_list,
-                                 transformer_engine::DType otype) {
+ void fused_multi_cast_transpose_base(std::vector<at::Tensor> input_list,
+                                      std::vector<void*> scale_dptr_list,
+                                      std::vector<at::Tensor> cast_output_list,
+                                      std::vector<at::Tensor> transposed_output_list,
+                                      std::vector<void*> amax_dptr_list,
+                                      std::vector<void*> scale_inv_dptr_list,
+                                      transformer_engine::DType otype) {
using namespace transformer_engine;
// Extract properties from PyTorch tensors
-   std::vector<void*> input_dptr_list, scale_dptr_list, cast_output_dptr_list,
-       transposed_output_dptr_list, amax_dptr_list, scale_inv_dptr_list;
-   std::vector<std::vector<size_t>> input_shape_list, scale_shape_list, cast_output_shape_list,
-       transposed_output_shape_list, amax_shape_list, scale_inv_shape_list;
-   std::vector<transformer_engine::DType> input_type_list, scale_type_list, cast_output_type_list,
-       transposed_output_type_list, amax_type_list, scale_inv_type_list;
+   std::vector<void*> input_dptr_list, cast_output_dptr_list, transposed_output_dptr_list;
+   std::vector<std::vector<size_t>> input_shape_list, cast_output_shape_list,
+       transposed_output_shape_list;
+   std::vector<transformer_engine::DType> input_type_list, cast_output_type_list,
+       transposed_output_type_list;
auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor, std::vector<void*>& dptr_list,
std::vector<std::vector<size_t>>& shape_list) {
dptr_list.push_back(tensor.data_ptr());
@@ -232,20 +231,14 @@ void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
};
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
extract_tensor_props(input_list[tensor_id], input_dptr_list, input_shape_list, input_type_list);
-     extract_tensor_props(scale_list[tensor_id], scale_dptr_list, scale_shape_list, scale_type_list);
extract_tensor_props_skip_dtype(cast_output_list[tensor_id], cast_output_dptr_list,
cast_output_shape_list);
cast_output_type_list.push_back(otype);
extract_tensor_props_skip_dtype(transposed_output_list[tensor_id], transposed_output_dptr_list,
transposed_output_shape_list);
transposed_output_type_list.push_back(otype);
-     extract_tensor_props(amax_list[tensor_id], amax_dptr_list, amax_shape_list, amax_type_list);
-     extract_tensor_props(scale_inv_list[tensor_id], scale_inv_dptr_list, scale_inv_shape_list,
-                          scale_inv_type_list);
}
transformer_engine::TensorWrapper workspace;
// Construct TE tensors
std::vector<NVTETensor> nvte_input_list, nvte_cast_output_list, nvte_transposed_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
@@ -257,6 +250,7 @@ void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
return tensor_wrappers.back().data();
};
for (size_t i = 0; i < input_dptr_list.size(); ++i) {
+     if (input_dptr_list[i] == nullptr) continue;
nvte_input_list.emplace_back(make_tensor(input_dptr_list[i], input_shape_list[i],
input_type_list[i], nullptr, nullptr, nullptr));
nvte_cast_output_list.emplace_back(
@@ -280,6 +274,55 @@ void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
at::cuda::getCurrentCUDAStream());
}
+ void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
+                                 std::vector<at::Tensor> scale_list,
+                                 std::vector<at::Tensor> cast_output_list,
+                                 std::vector<at::Tensor> transposed_output_list,
+                                 std::vector<at::Tensor> amax_list,
+                                 std::vector<at::Tensor> scale_inv_list,
+                                 transformer_engine::DType otype) {
+   std::vector<void*> scale_dptr_list, amax_dptr_list, scale_inv_dptr_list;
+   for (size_t i = 0; i < scale_list.size(); ++i) {
+     scale_dptr_list.push_back(scale_list[i].data_ptr());
+     amax_dptr_list.push_back(amax_list[i].data_ptr());
+     scale_inv_dptr_list.push_back(scale_inv_list[i].data_ptr());
+   }
+   fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list,
+                                   transposed_output_list, amax_dptr_list, scale_inv_dptr_list,
+                                   otype);
+ }
+
+ std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> fused_multi_cast_transpose_alloc(
+     std::vector<at::Tensor> input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
+     std::vector<int> scale_indices, std::vector<int> amax_indices,
+     std::vector<int> scale_inv_indices, transformer_engine::DType otype) {
+   using namespace transformer_engine;
+
+   std::vector<at::Tensor> cast_output_list;
+   std::vector<at::Tensor> transposed_output_list;
+   std::vector<void*> scale_dptr_list, amax_dptr_list, scale_inv_dptr_list;
+   for (size_t i = 0; i < input_list.size(); ++i) {
+     auto input_i = input_list[i];
+     // construct cast output tensors
+     auto cast_output_i = allocateTorchTensor(input_i.size(0), input_i.size(1), DType::kByte);
+     cast_output_list.push_back(cast_output_i);
+     // construct transposed output tensors
+     auto transposed_output_i = allocateTorchTensor(input_i.size(1), input_i.size(0), DType::kByte);
+     transposed_output_list.push_back(transposed_output_i);
+     // construct amax/scale/scale_inv dptr lists
+     amax_dptr_list.push_back(getDataPtr(amax, amax_indices[i]));
+     scale_dptr_list.push_back(getDataPtr(scale, scale_indices[i]));
+     scale_inv_dptr_list.push_back(getDataPtr(scale_inv, scale_inv_indices[i]));
+   }
+   fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list,
+                                   transposed_output_list, amax_dptr_list, scale_inv_dptr_list,
+                                   otype);
+   return std::make_tuple(std::move(cast_output_list), std::move(transposed_output_list));
+ }
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype) {
using namespace transformer_engine;
......
@@ -258,7 +258,8 @@ at::Tensor te_gemm_ts(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_te
// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs
-   const int sm_count = transformer_engine::cuda::sm_count();
+   const int device_id = at::cuda::current_device();
+   const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
@@ -293,7 +294,8 @@ std::vector<at::Tensor> te_grouped_gemm_ts(
// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs
-   const int sm_count = transformer_engine::cuda::sm_count();
+   const int device_id = at::cuda::current_device();
+   const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
te_grouped_gemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse,
......
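The two hunks above make the SM-count query explicit about which device it targets instead of relying on an implicit current-device lookup, in line with the "avoid cudaGetDriverEntryPoint" item in the commit message. For illustration only, the same arithmetic expressed in Python (`NVTE_EXT_MARGIN_SM` is the real environment variable from the diff; everything else is a sketch):

    import os
    import torch

    device_id = torch.cuda.current_device()
    sm_count = torch.cuda.get_device_properties(device_id).multi_processor_count
    num_math_sms = sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", sm_count))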
@@ -34,7 +34,7 @@ from ..distributed import (
from ..cpp_extensions import (
cast_to_fp8,
fp8_cast_transpose_bgrad_fused,
-     fp8_cast_transpose_fused,
+     fp8_multi_cast_transpose_fused,
fp8_grouped_gemm,
grouped_gemm,
)
@@ -82,12 +82,12 @@ class _GroupedLinear(torch.autograd.Function):
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
-         weights_fp8: List[Union[Float8Tensor, None]],
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
) -> torch.Tensor:
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
-         biases = weights_and_biases[num_gemms:]
+         weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]
+         biases = weights_and_biases[2 * num_gemms :]
# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
@@ -113,15 +113,15 @@
and not sequence_parallel
):
# FP8 input for forward, FP8 input transpose for backward wgrad
-             for i in range(num_gemms):
-                 mat, mat_t = fp8_cast_transpose_fused(
-                     inputmats_no_fp8[i],
-                     fp8_meta["scaling_fwd"],
-                     _GEMM_INPUT + i,
-                     fp8_dtype_forward,
-                 )
-                 inputmats.append(mat)
-                 inputmats_t.append(mat_t)
+             indices = list(range(_GEMM_INPUT, _GEMM_INPUT + num_gemms))
+             inputmats, inputmats_t = fp8_multi_cast_transpose_fused(
+                 inputmats_no_fp8,
+                 fp8_meta["scaling_fwd"],
+                 indices,  # scale_indices
+                 indices,  # amax_indices
+                 indices,  # scale_inv_indices
+                 fp8_dtype_forward,
+             )
else:
# FP8 input for forward
inputmats = [
@@ -308,13 +308,15 @@ class _GroupedLinear(torch.autograd.Function):
)
else:
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
-                 for i in range(ctx.num_gemms):
-                     grad_output_c[i], grad_output_t[i] = fp8_cast_transpose_fused(
-                         grad_output_mats[i],
-                         ctx.fp8_meta["scaling_bwd"],
-                         _GRAD_OUTPUT + i,
-                         fp8_dtype_backward,
-                     )
+                 indices = list(range(_GRAD_OUTPUT, _GRAD_OUTPUT + ctx.num_gemms))
+                 grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused(
+                     grad_output_mats,
+                     ctx.fp8_meta["scaling_bwd"],
+                     indices,  # scale_indices
+                     indices,  # amax_indices
+                     indices,  # scale_inv_indices
+                     fp8_dtype_backward,
+                 )
else:
for i in range(ctx.num_gemms):
grad_output_c[i] = cast_to_fp8(
@@ -334,7 +336,7 @@
if ctx.requires_dgrad:
if ctx.fp8:
dgrad = torch.empty(
-                 (sum(ctx.m_splits), weights_fp8[i].size(1)),
+                 (sum(ctx.m_splits), weights_fp8[0].size(1)),
dtype=ctx.activation_dtype,
device=grad_output.device,
)
@@ -487,8 +489,8 @@ class _GroupedLinear(torch.autograd.Function):
None, # activation_dtype
None, # parallel_mode
None, # is_grad_enabled
-             None, # weights_fp8
              *wgrad_list,
+             *([None] * ctx.num_gemms), # weights_fp8
*grad_biases,
)
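The change above follows from moving `weights_fp8` into the flattened varargs: `torch.autograd.Function.backward` must return one gradient slot per forward input, in order, so the single `None` placeholder becomes `num_gemms` placeholders. A toy, runnable illustration of the convention (the class name, shapes, and loss are invented for the demo; only the layout mirrors `_GroupedLinear`):

    import torch

    class _Demo(torch.autograd.Function):
        """Mimics the flattened layout: [weights..., weights_fp8..., biases...]."""

        @staticmethod
        def forward(ctx, num_gemms, *weights_fp8_and_biases):
            weights = weights_fp8_and_biases[:num_gemms]
            biases = weights_fp8_and_biases[2 * num_gemms :]  # fp8 copies carry no grad
            ctx.num_gemms = num_gemms
            out = torch.zeros(())
            for w, b in zip(weights, biases):
                out = out + w.sum() + b.sum()
            return out

        @staticmethod
        def backward(ctx, grad_out):
            n = ctx.num_gemms
            wgrad_list = [grad_out * torch.ones(4) for _ in range(n)]
            grad_biases = [grad_out * torch.ones(4) for _ in range(n)]
            # One slot per forward input: num_gemms, weights, weights_fp8, biases.
            return (None, *wgrad_list, *([None] * n), *grad_biases)

    n = 2
    weights = [torch.randn(4, requires_grad=True) for _ in range(n)]
    biases = [torch.randn(4, requires_grad=True) for _ in range(n)]
    out = _Demo.apply(n, *weights, *([None] * n), *biases)
    out.backward()
    assert all(torch.equal(w.grad, torch.ones(4)) for w in weights)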
@@ -799,8 +801,8 @@ class GroupedLinear(TransformerEngineBaseModule):
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
-             weight_tensors_fp8,
              *weight_tensors,
+             *weight_tensors_fp8,
*bias_tensors,
)
out = linear_fn(*args)
......