Unverified commit 047a5072, authored by Xin Yao, committed by GitHub
Browse files

[PyTorch] Propagate fp8 scale-inverse modification to `GroupedLinear` (#1128)



* propagate scale_inv modification to GroupedLinear
Signed-off-by: Xin Yao <xiny@nvidia.com>

* optimization for separate scale_inv of weights and single output
Signed-off-by: Xin Yao <xiny@nvidia.com>

* let grouped gemm support different input combinations
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix type
Signed-off-by: Xin Yao <xiny@nvidia.com>

* add contiguous check
Signed-off-by: Xin Yao <xiny@nvidia.com>

* use len() instead of isinstance
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix ut
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent bdea56fc
...@@ -1266,12 +1266,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False ...@@ -1266,12 +1266,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
) )
inp_hidden_states.retain_grad() inp_hidden_states.retain_grad()
m = config.seq_len // 16 if num_gemms > 1:
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() m = config.seq_len // 16
dist.append(dist[-1]) # Manually add a zero dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) dist.append(dist[-1]) # Manually add a zero
m_splits = m_splits * 16 m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms m_splits = m_splits * 16
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
else:
m_splits = torch.tensor([config.seq_len])
with fp8_autocast(enabled=fp8): with fp8_autocast(enabled=fp8):
if isinstance(block, GroupedLinear): if isinstance(block, GroupedLinear):
...@@ -1353,7 +1356,7 @@ def test_grouped_linear_accuracy( ...@@ -1353,7 +1356,7 @@ def test_grouped_linear_accuracy(
@pytest.mark.parametrize("parallel_mode", ["column", "row"]) @pytest.mark.parametrize("parallel_mode", ["column", "row"])
def test_grouped_linear_accuracy_parallel_mode(parallel_mode): def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
"""Split the tests to reduce CI time""" """Split the tests to save CI time"""
test_grouped_linear_accuracy( test_grouped_linear_accuracy(
dtype=torch.float32, dtype=torch.float32,
num_gemms=6, num_gemms=6,
...@@ -1365,6 +1368,18 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): ...@@ -1365,6 +1368,18 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
) )
def test_grouped_linear_accuracy_single_gemm():
    """Split the tests to save CI time"""
    # Exercise the num_gemms == 1 path (single scale_inv / single output)
    # with FP8 enabled for both compute and model params.
    single_gemm_args = dict(
        dtype=torch.float32,
        num_gemms=1,
        bs=2,
        model=list(model_configs.keys())[0],
        fp8=True,
        fp8_model_params=True,
    )
    test_grouped_linear_accuracy(**single_gemm_args)
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
...@@ -2034,7 +2049,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): ...@@ -2034,7 +2049,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
fp8_grouped_gemm( fp8_grouped_gemm(
A_fp8, A_fp8,
scale_inv, [scale_inv],
0, # A_offset 0, # A_offset
tex.DType.kFloat8E4M3, tex.DType.kFloat8E4M3,
B_fp8, B_fp8,
......
...@@ -11,7 +11,12 @@ from ..constants import TE_DType ...@@ -11,7 +11,12 @@ from ..constants import TE_DType
from ..utils import assert_dim_for_fp8_exec from ..utils import assert_dim_for_fp8_exec
__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"] __all__ = [
"gemm",
"fp8_gemm",
"grouped_gemm",
"fp8_grouped_gemm",
]
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
...@@ -313,7 +318,7 @@ def grouped_gemm( ...@@ -313,7 +318,7 @@ def grouped_gemm(
layout: str = "TN", layout: str = "TN",
bias: Optional[List[torch.Tensor]] = None, bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False, use_bias: bool = False,
) -> Tuple[Union[List[torch.Tensor], None], ...]: ) -> Tuple[List[torch.Tensor], ...]:
"""Non FP8 Grouped GEMM.""" """Non FP8 Grouped GEMM."""
assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
...@@ -380,7 +385,7 @@ def grouped_gemm( ...@@ -380,7 +385,7 @@ def grouped_gemm(
def fp8_grouped_gemm( def fp8_grouped_gemm(
A: List[torch.Tensor], A: List[torch.Tensor],
A_scale_inv: torch.Tensor, A_scale_inv: List[torch.Tensor],
A_fp8_tensor_offset: int, A_fp8_tensor_offset: int,
A_dtype: tex.DType, A_dtype: tex.DType,
B: List[torch.Tensor], B: List[torch.Tensor],
...@@ -390,6 +395,7 @@ def fp8_grouped_gemm( ...@@ -390,6 +395,7 @@ def fp8_grouped_gemm(
out: List[torch.Tensor], out: List[torch.Tensor],
out_dtype: torch.dtype, out_dtype: torch.dtype,
workspaces: List[torch.Tensor], workspaces: List[torch.Tensor],
m_splits: Optional[List[int]] = None,
out_offset: Optional[int] = None, out_offset: Optional[int] = None,
fp8_meta_tensor: tex.FP8TensorMeta = None, fp8_meta_tensor: tex.FP8TensorMeta = None,
gelu: bool = False, gelu: bool = False,
...@@ -398,16 +404,25 @@ def fp8_grouped_gemm( ...@@ -398,16 +404,25 @@ def fp8_grouped_gemm(
use_bias: bool = False, use_bias: bool = False,
use_split_accumulator: bool = False, use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None, D_dtype: Optional[tex.DType] = None,
) -> Tuple[Union[List[torch.Tensor], None], ...]: ) -> Tuple[List[torch.Tensor], ...]:
""" """
TN layout Grouped GEMM with fp8 inputs. TN layout Grouped GEMM with fp8 inputs.
This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor. Input requirements:
scale: [ ...A_scale... | ...B_scale... | ...out_scale...] 1. If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits is not None.
scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...] This is used for the calculation of output (fwd) and dgrad (bwd).
amax: [ ...A_amax... | ...B_amax... | ...out_amax...] 2. if len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the
calculation of wgrad.
""" """
num_gemms = len(A) num_gemms = len(A)
if num_gemms > 1 and len(A_scale_inv) == num_gemms:
assert len(out) == 1 and m_splits is not None
elif num_gemms > 1 and len(A_scale_inv) == 1:
assert len(out) == num_gemms
elif num_gemms == 1:
assert len(A_scale_inv) == 1 and len(out) == 1
else:
raise ValueError("Invalid input combinations of A_scale_inv and out.")
empty_tensor = _empty_tensor() empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms empty_tensors = [empty_tensor] * num_gemms
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
...@@ -420,41 +435,71 @@ def fp8_grouped_gemm( ...@@ -420,41 +435,71 @@ def fp8_grouped_gemm(
# Use bfloat16 as default bias_dtype # Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
]
else:
gelu_input = empty_tensors
bias_dtype = TE_DType[bias_dtype] bias_dtype = TE_DType[bias_dtype]
gelu_input = empty_tensors
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
torch.ops.tex_ts.te_grouped_gemm_ts( if len(A_scale_inv) == 1:
A, if gelu:
A_scale_inv, gelu_input = [
A_fp8_tensor_offset, torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
A_dtype, for o in out
True, # transa ]
B,
B_scale_inv, torch.ops.tex_ts.te_grouped_gemm_ts(
B_fp8_tensor_offset, A,
B_dtype, A_scale_inv[0],
False, # transb A_fp8_tensor_offset,
out, A_dtype,
0 if out_offset is None else out_offset, True, # transa
empty_tensor if out_offset is None else fp8_meta_tensor.scale, B,
out_dtype, B_scale_inv,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, B_fp8_tensor_offset,
bias if use_bias else empty_tensors, B_dtype,
bias_dtype, False, # transb
gelu_input, # this is pre_gelu_out out,
False, # grad 0 if out_offset is None else out_offset,
workspaces, empty_tensor if out_offset is None else fp8_meta_tensor.scale,
workspaces[0].shape[0], out_dtype,
accumulate, empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
use_split_accumulator, bias if use_bias else empty_tensors,
) bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
else:
if gelu:
gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits]
torch.ops.tex_ts.te_grouped_gemm_single_output_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
m_splits,
out[0],
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
return out, gelu_input return out, gelu_input
...@@ -175,6 +175,7 @@ def fp8_multi_cast_transpose_fused( ...@@ -175,6 +175,7 @@ def fp8_multi_cast_transpose_fused(
amax_indices: List[int], amax_indices: List[int],
scale_inv_indices: List[int], scale_inv_indices: List[int],
otype: tex.DType, otype: tex.DType,
scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Cast + Transpose with FP8 output""" """Cast + Transpose with FP8 output"""
...@@ -182,7 +183,7 @@ def fp8_multi_cast_transpose_fused( ...@@ -182,7 +183,7 @@ def fp8_multi_cast_transpose_fused(
input_list, input_list,
fp8_meta_tensor.scale, fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history, fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv, scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv,
scale_indices, scale_indices,
amax_indices, amax_indices,
scale_inv_indices, scale_inv_indices,
......
...@@ -165,6 +165,16 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int ...@@ -165,6 +165,16 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count); bool use_split_accumulator, int math_sm_count);
/* Grouped GEMM variant that writes every per-group result into one contiguous
 * output tensor D, partitioned along its first dimension by m_splits.
 * Unlike te_grouped_gemm above, each A[i] carries its own scale-inverse tensor
 * (A_scale_inverse is a vector), while B still indexes a single meta tensor
 * via B_offset. Defined in the corresponding .cpp translation unit. */
void te_grouped_gemm_single_output(
    std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
    transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
    at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
    std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
    transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
    transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
    std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
    bool use_split_accumulator, int math_sm_count);
/*************************************************************************************************** /***************************************************************************************************
* Transpose * Transpose
**************************************************************************************************/ **************************************************************************************************/
......
...@@ -151,3 +151,64 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int ...@@ -151,3 +151,64 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
te_workspace.data(), accumulate, use_split_accumulator, te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream()); math_sm_count, at::cuda::getCurrentCUDAStream());
} }
/* Grouped GEMM whose per-group outputs all land in ONE contiguous tensor D.
 *
 * D is partitioned along its first dimension by m_splits: group i occupies
 * m_splits[i] rows starting right after group i-1's rows. Each A[i] has its
 * own scale-inverse tensor in A_scale_inverse (indexed at A_offset inside
 * that tensor), while all B[i] share the single meta tensor B_scale_inverse
 * (indexed at B_offset + i). D_scale/D_amax are likewise indexed at
 * D_offset + i per group. The actual computation is dispatched to the
 * multi-stream cuBLAS backend.
 *
 * NOTE(review): output row count for group i is taken as m_splits[i] and the
 * column count as A[i].size(0) — this matches the TN layout the Python
 * caller uses (transa=true); confirm if reused with other layouts. */
void te_grouped_gemm_single_output(
    std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
    transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
    at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
    std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
    transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
    transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
    std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
    bool use_split_accumulator, int math_sm_count) {
  using namespace transformer_engine;
  std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
  // Wrap a raw pointer as an NVTETensor; tensor_wrappers owns the wrappers so
  // the NVTETensor handles stay valid until the GEMM call below returns.
  auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
                                        transformer_engine::DType dtype, void* amax_dptr,
                                        void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
    tensor_wrappers.emplace_back(
        makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
    return tensor_wrappers.back().data();
  };
  // Contiguity of D is what makes the manual pointer walk below valid.
  NVTE_CHECK(D.is_contiguous(), "D must be contiguous.");
  void* d_i_ptr = reinterpret_cast<void*>(D.data_ptr());
  for (size_t i = 0; i < A.size(); i++) {
    // Empty splits produce no rows in D; skip them entirely.
    if (m_splits[i] == 0) continue;
    NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous.");
    NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous.");
    // A: per-group scale-inverse tensor, shared offset A_offset within it.
    te_A.emplace_back(make_tensor(
        A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
        A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset)));
    // B: single meta tensor, per-group offset B_offset + i.
    te_B.emplace_back(make_tensor(
        B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
        B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
    // D: view into the shared output buffer at the current split offset.
    te_D.emplace_back(make_tensor(
        d_i_ptr, {static_cast<size_t>(m_splits[i]), static_cast<size_t>(A[i].size(0))}, D_type,
        getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
    te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
                                     bias_type, nullptr, nullptr, nullptr));
    // pre_gelu_out may be an empty placeholder (nullptr data) when gelu fusion
    // is disabled; its shape is then taken as 1-D.
    const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
                                ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
                                : std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
                                                      static_cast<size_t>(pre_gelu_out[i].size(1))};
    te_pre_gelu_out.emplace_back(make_tensor(
        pre_gelu_out[i].data_ptr(), gelu_shape,
        GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
    // Move the D pointer to the next split.
    char* char_ptr = reinterpret_cast<char*>(d_i_ptr);
    char_ptr += m_splits[i] * A[i].size(0) * D.element_size();
    d_i_ptr = reinterpret_cast<void*>(char_ptr);
  }
  for (size_t i = 0; i < workspace.size(); i++) {
    te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
                                          nullptr, nullptr, nullptr));
  }
  // For now, we only have multi-stream cublas backend.
  nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
                                te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
                                te_workspace.data(), accumulate, use_split_accumulator,
                                math_sm_count, at::cuda::getCurrentCUDAStream());
}
...@@ -305,6 +305,41 @@ std::vector<at::Tensor> te_grouped_gemm_ts( ...@@ -305,6 +305,41 @@ std::vector<at::Tensor> te_grouped_gemm_ts(
return D; return D;
} }
/* TorchScript-compatible wrapper around te_grouped_gemm_single_output.
 *
 * TorchScript only passes int64_t for scalar arguments, so every enum/bool/
 * size_t parameter arrives as int64_t and is converted back here before
 * delegating. Returns D (the shared output tensor) so the op has a value in
 * the TorchScript graph. Registered under tex_ts in TORCH_LIBRARY below. */
at::Tensor te_grouped_gemm_single_output_ts(
    std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int64_t A_offset,
    int64_t A_type, int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse,
    int64_t B_offset, int64_t B_type, int64_t transb, std::vector<int64_t> m_splits, at::Tensor D,
    int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax,
    std::vector<at::Tensor> bias, int64_t bias_type, std::vector<at::Tensor> pre_gelu_out,
    int64_t grad, std::vector<at::Tensor> workspace, int64_t workspaceSize, int64_t accumulate,
    int64_t use_split_accumulator) {
  // cast inputs to types accepted by te_gemm
  transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
  bool transa_arg = static_cast<bool>(transa);
  transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
  bool transb_arg = static_cast<bool>(transb);
  transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
  transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
  bool grad_arg = static_cast<bool>(grad);
  size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
  bool accumulate_arg = static_cast<bool>(accumulate);
  bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);

  // Set an external SM Margin to all the GEMMs.
  // This comes in handy when DP is overlapped with GEMMs
  const int device_id = at::cuda::current_device();
  const int sm_count = transformer_engine::cuda::sm_count(device_id);
  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

  te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B,
                                B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D,
                                D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg,
                                pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
                                accumulate_arg, use_split_accumulator_arg, num_math_sms);
  return D;
}
at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, double eps, at::Tensor scale, const at::Tensor &bias, double eps, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor,
...@@ -371,6 +406,7 @@ TORCH_LIBRARY(tex_ts, m) { ...@@ -371,6 +406,7 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("srelu_ts", &srelu_ts); m.def("srelu_ts", &srelu_ts);
m.def("te_gemm_ts", &te_gemm_ts); m.def("te_gemm_ts", &te_gemm_ts);
m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts); m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts);
m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts); m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);
......
...@@ -42,6 +42,7 @@ from ..constants import GemmParallelModes, dist_group_type ...@@ -42,6 +42,7 @@ from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
__all__ = ["GroupedLinear"] __all__ = ["GroupedLinear"]
...@@ -102,10 +103,12 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -102,10 +103,12 @@ class _GroupedLinear(torch.autograd.Function):
inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats] inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
inputmats = [] inputmats = []
inputmats_t = [] inputmats_t = []
inputmat_scale_inv = None
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
if fp8: if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device)
if ( if (
not fp8_meta["recipe"].override_linear_precision.wgrad not fp8_meta["recipe"].override_linear_precision.wgrad
and is_grad_enabled and is_grad_enabled
...@@ -121,6 +124,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -121,6 +124,7 @@ class _GroupedLinear(torch.autograd.Function):
indices, # amax_indices indices, # amax_indices
indices, # scale_inv_indices indices, # scale_inv_indices
fp8_dtype_forward, fp8_dtype_forward,
scale_inv=inputmat_scale_inv,
) )
else: else:
# FP8 input for forward # FP8 input for forward
...@@ -130,9 +134,22 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -130,9 +134,22 @@ class _GroupedLinear(torch.autograd.Function):
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
_GEMM_INPUT + i, _GEMM_INPUT + i,
fp8_dtype_forward, fp8_dtype_forward,
scale_inv=inputmat_scale_inv,
) )
for i in range(num_gemms) for i in range(num_gemms)
] ]
# Hack for ONNX export
# Note: ONNX models are represented as a graph of tensor
# operations, so the in-place scale-inv update doesn't fit
# very well. We work around this by making it look like
# the scale-inv tensor is initialized with a copy.
# Note: ONNX export expects FP8 scales can be represented
# with constant ops. However, copying into a buffer
# involves an expand op for array broadcasting. We work
# around this by filling the buffer instead.
if is_in_onnx_export_mode():
inputmat_scale_inv.fill_(inputmat_scale_inv.item())
else: else:
inputmats = inputmats_no_fp8 inputmats = inputmats_no_fp8
...@@ -153,16 +170,17 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -153,16 +170,17 @@ class _GroupedLinear(torch.autograd.Function):
_ = fp8_grouped_gemm( _ = fp8_grouped_gemm(
[w._data for w in weights_fp8], [w._data for w in weights_fp8],
fp8_meta["scaling_fwd"].scale_inv, [w._scale_inv for w in weights_fp8],
_GEMM_WEIGHT, 0, # weight offset is 0 for the newly created _scale_inv
fp8_dtype_forward, fp8_dtype_forward,
inputmats, inputmats,
fp8_meta["scaling_fwd"].scale_inv, inputmat_scale_inv,
_GEMM_INPUT, 0,
fp8_dtype_forward, fp8_dtype_forward,
torch.split(out, m_splits), [out],
activation_dtype, activation_dtype,
get_multi_stream_cublas_workspace(), get_multi_stream_cublas_workspace(),
m_splits=m_splits,
bias=biases, bias=biases,
use_bias=use_bias, use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP, use_split_accumulator=_2X_ACC_FPROP,
...@@ -230,7 +248,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -230,7 +248,7 @@ class _GroupedLinear(torch.autograd.Function):
t.activation_offloading = True t.activation_offloading = True
ctx.save_for_backward( ctx.save_for_backward(
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, inputmat_scale_inv,
*saved_inputmats, *saved_inputmats,
*saved_inputmats_t, *saved_inputmats_t,
*weights, *weights,
...@@ -270,7 +288,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -270,7 +288,7 @@ class _GroupedLinear(torch.autograd.Function):
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_GroupedLinear_backward"): with torch.cuda.nvtx.range("_GroupedLinear_backward"):
( (
fwd_scale_inverses, inputmat_scale_inv,
*saved_tensors, *saved_tensors,
) = ctx.saved_tensors ) = ctx.saved_tensors
inputmats = saved_tensors[: ctx.num_gemms] inputmats = saved_tensors[: ctx.num_gemms]
...@@ -342,18 +360,17 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -342,18 +360,17 @@ class _GroupedLinear(torch.autograd.Function):
) )
fp8_grouped_gemm( fp8_grouped_gemm(
[w.transpose_2d() for w in weights_fp8], [w.transpose_2d() for w in weights_fp8],
torch.cat( [w._scale_inv for w in weights_fp8],
[w._scale_inv for w in weights_fp8]
), # avoiding torch.cat requires another interface
0, # weight offset is 0 for the newly created _scale_inv 0, # weight offset is 0 for the newly created _scale_inv
weights_fp8[0]._fp8_dtype, weights_fp8[0]._fp8_dtype,
grad_output_c, grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv, ctx.fp8_meta["scaling_bwd"].scale_inv,
_GRAD_OUTPUT, _GRAD_OUTPUT,
fp8_dtype_backward, fp8_dtype_backward,
torch.split(dgrad, ctx.m_splits), [dgrad],
ctx.activation_dtype, ctx.activation_dtype,
get_multi_stream_cublas_workspace(), get_multi_stream_cublas_workspace(),
m_splits=ctx.m_splits,
use_split_accumulator=_2X_ACC_DGRAD, use_split_accumulator=_2X_ACC_DGRAD,
) )
else: else:
...@@ -396,8 +413,8 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -396,8 +413,8 @@ class _GroupedLinear(torch.autograd.Function):
inp._data if isinstance(inp, Float8Tensor) else inp inp._data if isinstance(inp, Float8Tensor) else inp
for inp in inputmats_t for inp in inputmats_t
], ],
fwd_scale_inverses, [inputmat_scale_inv],
_GEMM_INPUT, 0,
fp8_dtype_forward, fp8_dtype_forward,
grad_output_t, grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv, ctx.fp8_meta["scaling_bwd"].scale_inv,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment