Unverified Commit 047a5072, authored by Xin Yao and committed by GitHub

[PyTorch] Propagate fp8 scale-inverse modification to `GroupedLinear` (#1128)



* propagate scale_inv modification to GroupedLinear
Signed-off-by: Xin Yao <xiny@nvidia.com>

* optimization for separate scale_inv of weights and single output
Signed-off-by: Xin Yao <xiny@nvidia.com>

* let grouped gemm support different input combinations
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix type
Signed-off-by: Xin Yao <xiny@nvidia.com>

* add contiguous check
Signed-off-by: Xin Yao <xiny@nvidia.com>

* use len() instead of isinstance
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix ut
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent bdea56fc
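
For orientation, a rough usage sketch of the module this change touches. The constructor and forward arguments below are assumptions inferred from the accuracy tests in this diff, not an authoritative API reference:

import torch
from transformer_engine.pytorch import GroupedLinear, fp8_autocast

num_gemms, seq_len, hidden = 4, 512, 1024                 # hypothetical sizes
layer = GroupedLinear(num_gemms, hidden, hidden).cuda()   # assumed ctor order: (num_gemms, in_features, out_features)
inp = torch.randn(seq_len, hidden, device="cuda", requires_grad=True)
m_splits = [128, 256, 0, 128]                             # per-GEMM rows; must sum to seq_len

with fp8_autocast(enabled=True):
    out = layer(inp, m_splits)   # forward snapshots per-GEMM input scale_inv into a dedicated buffer
out.sum().backward()             # dgrad/wgrad reuse the saved scale_inv and the weights' own _scale_inv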
@@ -1266,12 +1266,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
)
inp_hidden_states.retain_grad()
if num_gemms > 1:
m = config.seq_len // 16
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * 16
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
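# Worked example (illustrative): with m = 32 and num_gemms = 4, a sorted draw of
# dist = [5, 12] becomes [5, 12, 12] after the append, so
# m_splits = [5, 12, 12, 32] - [0, 5, 12, 12] = [5, 7, 0, 20] and, after *16,
# [80, 112, 0, 320] -- one split is deliberately zero and the total is 32 * 16 = config.seq_len.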
else:
m_splits = torch.tensor([config.seq_len])
with fp8_autocast(enabled=fp8):
if isinstance(block, GroupedLinear):
@@ -1353,7 +1356,7 @@ def test_grouped_linear_accuracy(
@pytest.mark.parametrize("parallel_mode", ["column", "row"])
def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
"""Split the tests to reduce CI time"""
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=6,
@@ -1365,6 +1368,18 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
)
def test_grouped_linear_accuracy_single_gemm():
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=1,
bs=2,
model=list(model_configs.keys())[0],
fp8=True,
fp8_model_params=True,
)
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
@@ -2034,7 +2049,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
fp8_grouped_gemm(
A_fp8,
scale_inv,
[scale_inv],
0, # A_offset
tex.DType.kFloat8E4M3,
B_fp8,
......
@@ -11,7 +11,12 @@ from ..constants import TE_DType
from ..utils import assert_dim_for_fp8_exec
__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"]
__all__ = [
"gemm",
"fp8_gemm",
"grouped_gemm",
"fp8_grouped_gemm",
]
@functools.lru_cache(maxsize=None)
@@ -313,7 +318,7 @@ def grouped_gemm(
layout: str = "TN",
bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
) -> Tuple[List[torch.Tensor], ...]:
"""Non FP8 Grouped GEMM."""
assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
@@ -380,7 +385,7 @@ def grouped_gemm(
def fp8_grouped_gemm(
A: List[torch.Tensor],
A_scale_inv: torch.Tensor,
A_scale_inv: List[torch.Tensor],
A_fp8_tensor_offset: int,
A_dtype: tex.DType,
B: List[torch.Tensor],
@@ -390,6 +395,7 @@ def fp8_grouped_gemm(
out: List[torch.Tensor],
out_dtype: torch.dtype,
workspaces: List[torch.Tensor],
m_splits: Optional[List[int]] = None,
out_offset: Optional[int] = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
gelu: bool = False,
@@ -398,16 +404,25 @@ def fp8_grouped_gemm(
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
) -> Tuple[List[torch.Tensor], ...]:
"""
TN layout Grouped GEMM with fp8 inputs.
This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor.
scale: [ ...A_scale... | ...B_scale... | ...out_scale...]
scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...]
amax: [ ...A_amax... | ...B_amax... | ...out_amax...]
Input requirements:
1. If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits is not None.
This is used for the calculation of output (fwd) and dgrad (bwd).
2. if len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the
calculation of wgrad.
"""
num_gemms = len(A)
if num_gemms > 1 and len(A_scale_inv) == num_gemms:
assert len(out) == 1 and m_splits is not None
elif num_gemms > 1 and len(A_scale_inv) == 1:
assert len(out) == num_gemms
elif num_gemms == 1:
assert len(A_scale_inv) == 1 and len(out) == 1
else:
raise ValueError("Invalid input combinations of A_scale_inv and out.")
empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
@@ -420,20 +435,20 @@ def fp8_grouped_gemm(
# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
bias_dtype = TE_DType[bias_dtype]
gelu_input = empty_tensors
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
if len(A_scale_inv) == 1:
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
]
else:
gelu_input = empty_tensors
bias_dtype = TE_DType[bias_dtype]
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
torch.ops.tex_ts.te_grouped_gemm_ts(
A,
A_scale_inv,
A_scale_inv[0],
A_fp8_tensor_offset,
A_dtype,
True, # transa
@@ -456,5 +471,35 @@ def fp8_grouped_gemm(
accumulate,
use_split_accumulator,
)
else:
if gelu:
gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits]
torch.ops.tex_ts.te_grouped_gemm_single_output_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
m_splits,
out[0],
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
return out, gelu_input
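
To make the two accepted input combinations concrete, here is a minimal sketch of both call patterns, mirroring the GroupedLinear call sites later in this diff; the tensor names are hypothetical and arguments not shown keep their defaults:

# Mode 1 -- per-GEMM A scale_inv, one contiguous output, m_splits required
# (used for the forward output and for dgrad). Tensor names are hypothetical.
out_list, _ = fp8_grouped_gemm(
    [w0_fp8, w1_fp8],                     # A: per-GEMM FP8 weights
    [w0_scale_inv, w1_scale_inv],         # A_scale_inv: len == num_gemms
    0, tex.DType.kFloat8E4M3,
    [x0_fp8, x1_fp8],                     # B: per-GEMM FP8 inputs
    x_scale_inv, 0, tex.DType.kFloat8E4M3,  # B_scale_inv holds one value per GEMM, indexed from offset 0
    [full_out],                           # out: a single buffer covering all splits
    torch.bfloat16,
    get_multi_stream_cublas_workspace(),
    m_splits=[m0, m1],
)

# Mode 2 -- one shared A scale_inv, one output tensor per GEMM (used for wgrad).
wgrad_list, _ = fp8_grouped_gemm(
    [x0_t, x1_t],                         # A: transposed FP8 inputs
    [x_scale_inv],                        # A_scale_inv: len == 1
    0, tex.DType.kFloat8E4M3,
    [g0_fp8, g1_fp8],                     # B: FP8 grad outputs
    grad_scale_inv, 0, tex.DType.kFloat8E5M2,
    [wgrad0, wgrad1],                     # out: len == num_gemms
    torch.float32,
    get_multi_stream_cublas_workspace(),
)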
@@ -175,6 +175,7 @@ def fp8_multi_cast_transpose_fused(
amax_indices: List[int],
scale_inv_indices: List[int],
otype: tex.DType,
scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Cast + Transpose with FP8 output"""
@@ -182,7 +183,7 @@ def fp8_multi_cast_transpose_fused(
input_list,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv,
scale_indices,
amax_indices,
scale_inv_indices,
......
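
The new optional scale_inv argument redirects where the fused kernel writes the per-tensor scale-inverses. A minimal sketch of the pattern GroupedLinear uses (buffer and index names mirror the call later in this diff and are otherwise illustrative):

# Pre-allocate one float32 slot per GEMM; the fused cast+transpose then writes
# the scale_inv values here instead of into fp8_meta_tensor.scale_inv.
inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device)
inputmats, inputmats_t = fp8_multi_cast_transpose_fused(
    inputmats_no_fp8,          # input_list
    fp8_meta["scaling_fwd"],   # fp8_meta_tensor
    indices,                   # scale_indices
    indices,                   # amax_indices
    indices,                   # scale_inv_indices
    fp8_dtype_forward,         # otype
    scale_inv=inputmat_scale_inv,
)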
@@ -165,6 +165,16 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);
void te_grouped_gemm_single_output(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);
/***************************************************************************************************
* Transpose
**************************************************************************************************/
......
@@ -151,3 +151,64 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}
void te_grouped_gemm_single_output(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype, void* amax_dptr,
void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
tensor_wrappers.emplace_back(
makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
NVTE_CHECK(D.is_contiguous(), "D must be contiguous.");
void* d_i_ptr = reinterpret_cast<void*>(D.data_ptr());
for (size_t i = 0; i < A.size(); i++) {
if (m_splits[i] == 0) continue;
NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous.");
NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous.");
te_A.emplace_back(make_tensor(
A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset)));
te_B.emplace_back(make_tensor(
B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
te_D.emplace_back(make_tensor(
d_i_ptr, {static_cast<size_t>(m_splits[i]), static_cast<size_t>(A[i].size(0))}, D_type,
getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
bias_type, nullptr, nullptr, nullptr));
const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
static_cast<size_t>(pre_gelu_out[i].size(1))};
te_pre_gelu_out.emplace_back(make_tensor(
pre_gelu_out[i].data_ptr(), gelu_shape,
GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
// Move the D pointer to the next split.
char* char_ptr = reinterpret_cast<char*>(d_i_ptr);
char_ptr += m_splits[i] * A[i].size(0) * D.element_size();
d_i_ptr = reinterpret_cast<void*>(char_ptr);
}
for (size_t i = 0; i < workspace.size(); i++) {
te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
nullptr, nullptr, nullptr));
}
// For now, we only have multi-stream cublas backend.
nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}
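
te_grouped_gemm_single_output writes every GEMM's result into one contiguous buffer D, advancing the destination pointer by m_splits[i] rows after each GEMM. In PyTorch terms the partitioning is equivalent to the following (illustrative only, assuming D has shape [sum(m_splits), N]):

# Per-GEMM views over the single output buffer; no copies are made.
d_views = torch.split(D, m_splits)
# The kernel's pointer bump after GEMM i corresponds to
#   m_splits[i] * N * D.element_size()  bytes,
# where N == A[i].size(0), i.e. the output feature dimension of that GEMM.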
@@ -305,6 +305,41 @@ std::vector<at::Tensor> te_grouped_gemm_ts(
return D;
}
at::Tensor te_grouped_gemm_single_output_ts(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int64_t A_offset,
int64_t A_type, int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse,
int64_t B_offset, int64_t B_type, int64_t transb, std::vector<int64_t> m_splits, at::Tensor D,
int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax,
std::vector<at::Tensor> bias, int64_t bias_type, std::vector<at::Tensor> pre_gelu_out,
int64_t grad, std::vector<at::Tensor> workspace, int64_t workspaceSize, int64_t accumulate,
int64_t use_split_accumulator) {
// cast inputs to types accepted by te_gemm
transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
bool transa_arg = static_cast<bool>(transa);
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs
const int device_id = at::cuda::current_device();
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B,
B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D,
D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg,
pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
accumulate_arg, use_split_accumulator_arg, num_math_sms);
return D;
}
at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, double eps, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor,
@@ -371,6 +406,7 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("srelu_ts", &srelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts);
m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);
......
@@ -42,6 +42,7 @@ from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
__all__ = ["GroupedLinear"]
@@ -102,10 +103,12 @@ class _GroupedLinear(torch.autograd.Function):
inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
inputmats = []
inputmats_t = []
inputmat_scale_inv = None
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device)
if (
not fp8_meta["recipe"].override_linear_precision.wgrad
and is_grad_enabled
@@ -121,6 +124,7 @@ class _GroupedLinear(torch.autograd.Function):
indices, # amax_indices
indices, # scale_inv_indices
fp8_dtype_forward,
scale_inv=inputmat_scale_inv,
)
else:
# FP8 input for forward
@@ -130,9 +134,22 @@ class _GroupedLinear(torch.autograd.Function):
fp8_meta["scaling_fwd"],
_GEMM_INPUT + i,
fp8_dtype_forward,
scale_inv=inputmat_scale_inv,
)
for i in range(num_gemms)
]
# Hack for ONNX export
# Note: ONNX models are represented as a graph of tensor
# operations, so the in-place scale-inv update doesn't fit
# very well. We work around this by making it look like
# the scale-inv tensor is initialized with a copy.
# Note: ONNX export expects FP8 scales can be represented
# with constant ops. However, copying into a buffer
# involves an expand op for array broadcasting. We work
# around this by filling the buffer instead.
if is_in_onnx_export_mode():
inputmat_scale_inv.fill_(inputmat_scale_inv.item())
else:
inputmats = inputmats_no_fp8
@@ -153,16 +170,17 @@ class _GroupedLinear(torch.autograd.Function):
_ = fp8_grouped_gemm(
[w._data for w in weights_fp8],
fp8_meta["scaling_fwd"].scale_inv,
_GEMM_WEIGHT,
[w._scale_inv for w in weights_fp8],
0, # weight offset is 0 for the newly created _scale_inv
fp8_dtype_forward,
inputmats,
fp8_meta["scaling_fwd"].scale_inv,
_GEMM_INPUT,
inputmat_scale_inv,
0,
fp8_dtype_forward,
torch.split(out, m_splits),
[out],
activation_dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
bias=biases,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
@@ -230,7 +248,7 @@ class _GroupedLinear(torch.autograd.Function):
t.activation_offloading = True
ctx.save_for_backward(
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
inputmat_scale_inv,
*saved_inputmats,
*saved_inputmats_t,
*weights,
@@ -270,7 +288,7 @@ class _GroupedLinear(torch.autograd.Function):
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_GroupedLinear_backward"):
(
fwd_scale_inverses,
inputmat_scale_inv,
*saved_tensors,
) = ctx.saved_tensors
inputmats = saved_tensors[: ctx.num_gemms]
@@ -342,18 +360,17 @@ class _GroupedLinear(torch.autograd.Function):
)
fp8_grouped_gemm(
[w.transpose_2d() for w in weights_fp8],
torch.cat(
[w._scale_inv for w in weights_fp8]
), # avoiding torch.cat requires another interface
[w._scale_inv for w in weights_fp8],
0, # weight offset is 0 for the newly created _scale_inv
weights_fp8[0]._fp8_dtype,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
_GRAD_OUTPUT,
fp8_dtype_backward,
torch.split(dgrad, ctx.m_splits),
[dgrad],
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
m_splits=ctx.m_splits,
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
@@ -396,8 +413,8 @@ class _GroupedLinear(torch.autograd.Function):
inp._data if isinstance(inp, Float8Tensor) else inp
for inp in inputmats_t
],
fwd_scale_inverses,
_GEMM_INPUT,
[inputmat_scale_inv],
0,
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
......