Unverified Commit 8b326059 authored by Xin Yao, committed by GitHub

[PyTorch] Reduce the CPU overheads of `GroupedLinear` (#1072)

* use fused_multi_cast_transpose
  Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix input being empty tensor
  Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

* allocate output tensors in C++
  Signed-off-by: Xin Yao <xiny@nvidia.com>

* simplify code
  Signed-off-by: Xin Yao <xiny@nvidia.com>

* avoid cudaGetDriverEntryPoint
  Signed-off-by: Xin Yao <xiny@nvidia.com>

* reduce torch.Tensor() calls
  Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

* update test
  Signed-off-by: Xin Yao <xiny@nvidia.com>

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent fa4b866d
@@ -1228,7 +1228,8 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
     inp_hidden_states.retain_grad()
     m = config.seq_len // 16
-    dist = torch.sort(torch.randint(0, m, (num_gemms - 1,))).values.tolist()
+    dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
+    dist.append(dist[-1])  # Manually add a zero
     m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
     m_splits = m_splits * 16
     assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
...
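For context on the test change: duplicating the last cut point guarantees that at least one split has size zero, so the grouped GEMM's empty-input path is always exercised. A minimal standalone sketch with illustrative values (`num_gemms` and `m` are placeholders; requires `num_gemms >= 3`):

```python
import torch

num_gemms, m = 4, 32  # illustrative values
# Sorted random cut points; duplicating the last one makes two adjacent
# cut points equal, so the split between them has size zero.
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1])  # manually add a zero-sized split
m_splits = (torch.tensor(dist + [m]) - torch.tensor([0] + dist)) * 16
assert m_splits.sum() == m * 16 and len(m_splits) == num_gemms
assert (m_splits == 0).any()  # the empty-input path is exercised
```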
@@ -3,6 +3,7 @@
 # See LICENSE for license information.

 """Python interface for GEMM extensions"""
+import functools
 from typing import Optional, Tuple, Union, List

 import torch
 import transformer_engine_torch as tex
@@ -13,6 +14,12 @@ from ..utils import assert_dim_for_fp8_exec
 __all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"]


+@functools.lru_cache(maxsize=None)
+def _empty_tensor() -> torch.Tensor:
+    """Get tensor with no entries and no data"""
+    return torch.Tensor()
+
+
 def fp8_gemm(
     A: torch.Tensor,
     A_scale_inv: torch.Tensor,
@@ -39,7 +46,7 @@ def fp8_gemm(
 ) -> torch.Tensor:
     """TN layout GEMM with fp8 inputs."""

-    empty_tensor = torch.Tensor()
+    empty_tensor = _empty_tensor()
     if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
         assert fp8_meta_tensor is not None and out_index is not None
     assert_dim_for_fp8_exec(A)
@@ -195,7 +202,7 @@ def gemm(
     assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
     transa = layout[0] == "T"
     transb = layout[1] == "T"
-    empty_tensor = torch.Tensor()
+    empty_tensor = _empty_tensor()
     fp8_index = -1  # dummy index

     if out is None:
@@ -313,8 +320,8 @@ def grouped_gemm(
     transa = layout[0] == "T"
     transb = layout[1] == "T"
     num_gemms = len(A)
-    empty_tensor = torch.Tensor()
-    empty_tensors = [torch.Tensor()] * num_gemms
+    empty_tensor = _empty_tensor()
+    empty_tensors = [empty_tensor] * num_gemms

     if gelu and not grad:
         gelu_input = [
@@ -401,8 +408,8 @@ def fp8_grouped_gemm(
     """

     num_gemms = len(A)
-    empty_tensor = torch.Tensor()
-    empty_tensors = [torch.Tensor()] * num_gemms
+    empty_tensor = _empty_tensor()
+    empty_tensors = [empty_tensor] * num_gemms
     if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
         assert fp8_meta_tensor is not None and out_offset is not None
     for a, b in zip(A, B):
...
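The cached placeholder avoids constructing a fresh `torch.Tensor()` on every GEMM call, which is pure Python/CPU overhead; sharing one zero-element instance is safe because it carries no data and is never written. A self-contained sketch of the pattern:

```python
import functools

import torch

@functools.lru_cache(maxsize=None)
def _empty_tensor() -> torch.Tensor:
    """Return a cached tensor with no entries and no data."""
    return torch.Tensor()

empty = _empty_tensor()
assert empty.numel() == 0
assert _empty_tensor() is empty  # the same object on every call
placeholders = [empty] * 8       # eight references, no new allocations
```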
@@ -3,7 +3,7 @@
 # See LICENSE for license information.

 """Python interface for transpose extensions"""
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
 import transformer_engine_torch as tex

 from ..constants import TE_DType
@@ -13,6 +13,7 @@ __all__ = [
     "fp8_cast_transpose_fused",
     "fp8_cast_transpose_bgrad_fused",
     "fp8_cast_transpose_bgrad_dgelu_fused",
+    "fp8_multi_cast_transpose_fused",
     "fp8_transpose_bgrad_fused",
 ]
@@ -118,3 +119,25 @@ def fp8_cast_transpose_bgrad_dgelu_fused(
         amax_offset=int(fp8_tensor),
         scale_inv_offset=int(fp8_tensor),
     )
+
+
+def fp8_multi_cast_transpose_fused(
+    input_list: List[torch.Tensor],
+    fp8_meta_tensor: tex.FP8TensorMeta,
+    scale_indices: List[int],
+    amax_indices: List[int],
+    scale_inv_indices: List[int],
+    otype: tex.DType,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    """Cast + Transpose with FP8 output"""
+
+    return tex.fused_multi_cast_transpose_alloc(
+        input_list,
+        fp8_meta_tensor.scale,
+        fp8_meta_tensor.amax_history,
+        fp8_meta_tensor.scale_inv,
+        scale_indices,
+        amax_indices,
+        scale_inv_indices,
+        otype,
+    )
...
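For intuition, here is a pure-PyTorch sketch of the per-tensor semantics of the fused op (illustration only: the real kernel handles all tensors in a single launch and rounds to FP8, which this sketch omits, and the flat `scale`/`amax`/`scale_inv` layout is a simplification of `FP8TensorMeta`):

```python
from typing import List, Tuple

import torch

def multi_cast_transpose_reference(
    input_list: List[torch.Tensor],
    scale: torch.Tensor,      # per-slot scales (flattened for this sketch)
    amax: torch.Tensor,       # per-slot amaxes (the real op writes a history row)
    scale_inv: torch.Tensor,  # per-slot inverse scales
    indices: List[int],
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    casts, transposes = [], []
    for inp, idx in zip(input_list, indices):
        amax[idx] = inp.abs().max()        # record amax for the scaling recipe
        scale_inv[idx] = 1.0 / scale[idx]  # publish the inverse scale
        scaled = inp * scale[idx]          # FP8 rounding omitted here
        casts.append(scaled)
        transposes.append(scaled.t().contiguous())
    return casts, transposes
```

The fused path returns the same two lists but uses one Python call and one kernel launch for the whole group, which is where the CPU savings come from.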
@@ -180,6 +180,11 @@ void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
                                 std::vector<at::Tensor> scale_inv_output_list,
                                 transformer_engine::DType otype);

+std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> fused_multi_cast_transpose_alloc(
+    std::vector<at::Tensor> input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
+    std::vector<int> scale_indices, std::vector<int> amax_indices,
+    std::vector<int> scale_inv_indices, transformer_engine::DType otype);
+
 at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype);

 void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype);
...
@@ -84,6 +84,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
   m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
         "Fused Multi-tensor Cast + Transpose", py::call_guard<py::gil_scoped_release>());
+  m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc,
+        "Fused Multi-tensor Cast + Transpose with allocating output tensors",
+        py::call_guard<py::gil_scoped_release>());
   m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard<py::gil_scoped_release>());
   m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8",
         py::call_guard<py::gil_scoped_release>());
...
@@ -75,7 +75,7 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output, at::T
   // Return immediately if tensors are empty
   if (M == 0 || N == 0) {
-    return {grad_bias, grad_output_cast, grad_output_transpose};
+    return {grad_bias.zero_(), grad_output_cast, grad_output_transpose};
   }

   // Get pointers for FP8 scale, amax, scale-inverse
@@ -196,22 +196,21 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
   return {grad_bias, dgelu, dgelu_transpose};
 }

-void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
-                                std::vector<at::Tensor> scale_list,
-                                std::vector<at::Tensor> cast_output_list,
-                                std::vector<at::Tensor> transposed_output_list,
-                                std::vector<at::Tensor> amax_list,
-                                std::vector<at::Tensor> scale_inv_list,
-                                transformer_engine::DType otype) {
+void fused_multi_cast_transpose_base(std::vector<at::Tensor> input_list,
+                                     std::vector<void*> scale_dptr_list,
+                                     std::vector<at::Tensor> cast_output_list,
+                                     std::vector<at::Tensor> transposed_output_list,
+                                     std::vector<void*> amax_dptr_list,
+                                     std::vector<void*> scale_inv_dptr_list,
+                                     transformer_engine::DType otype) {
   using namespace transformer_engine;

   // Extract properties from PyTorch tensors
-  std::vector<void*> input_dptr_list, scale_dptr_list, cast_output_dptr_list,
-      transposed_output_dptr_list, amax_dptr_list, scale_inv_dptr_list;
-  std::vector<std::vector<size_t>> input_shape_list, scale_shape_list, cast_output_shape_list,
-      transposed_output_shape_list, amax_shape_list, scale_inv_shape_list;
-  std::vector<transformer_engine::DType> input_type_list, scale_type_list, cast_output_type_list,
-      transposed_output_type_list, amax_type_list, scale_inv_type_list;
+  std::vector<void*> input_dptr_list, cast_output_dptr_list, transposed_output_dptr_list;
+  std::vector<std::vector<size_t>> input_shape_list, cast_output_shape_list,
+      transposed_output_shape_list;
+  std::vector<transformer_engine::DType> input_type_list, cast_output_type_list,
+      transposed_output_type_list;
   auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor, std::vector<void*>& dptr_list,
                                             std::vector<std::vector<size_t>>& shape_list) {
     dptr_list.push_back(tensor.data_ptr());
@@ -232,20 +231,14 @@ void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
   };
   for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
     extract_tensor_props(input_list[tensor_id], input_dptr_list, input_shape_list, input_type_list);
-    extract_tensor_props(scale_list[tensor_id], scale_dptr_list, scale_shape_list, scale_type_list);
     extract_tensor_props_skip_dtype(cast_output_list[tensor_id], cast_output_dptr_list,
                                     cast_output_shape_list);
     cast_output_type_list.push_back(otype);
     extract_tensor_props_skip_dtype(transposed_output_list[tensor_id], transposed_output_dptr_list,
                                     transposed_output_shape_list);
     transposed_output_type_list.push_back(otype);
-    extract_tensor_props(amax_list[tensor_id], amax_dptr_list, amax_shape_list, amax_type_list);
-    extract_tensor_props(scale_inv_list[tensor_id], scale_inv_dptr_list, scale_inv_shape_list,
-                         scale_inv_type_list);
   }

-  transformer_engine::TensorWrapper workspace;
-
   // Construct TE tensors
   std::vector<NVTETensor> nvte_input_list, nvte_cast_output_list, nvte_transposed_output_list;
   std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
@@ -257,6 +250,7 @@ void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
     return tensor_wrappers.back().data();
   };
   for (size_t i = 0; i < input_dptr_list.size(); ++i) {
+    if (input_dptr_list[i] == nullptr) continue;
     nvte_input_list.emplace_back(make_tensor(input_dptr_list[i], input_shape_list[i],
                                              input_type_list[i], nullptr, nullptr, nullptr));
     nvte_cast_output_list.emplace_back(
@@ -280,6 +274,55 @@ void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
                                  at::cuda::getCurrentCUDAStream());
 }

+void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
+                                std::vector<at::Tensor> scale_list,
+                                std::vector<at::Tensor> cast_output_list,
+                                std::vector<at::Tensor> transposed_output_list,
+                                std::vector<at::Tensor> amax_list,
+                                std::vector<at::Tensor> scale_inv_list,
+                                transformer_engine::DType otype) {
+  std::vector<void*> scale_dptr_list, amax_dptr_list, scale_inv_dptr_list;
+  for (size_t i = 0; i < scale_list.size(); ++i) {
+    scale_dptr_list.push_back(scale_list[i].data_ptr());
+    amax_dptr_list.push_back(amax_list[i].data_ptr());
+    scale_inv_dptr_list.push_back(scale_inv_list[i].data_ptr());
+  }
+
+  fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list,
+                                  transposed_output_list, amax_dptr_list, scale_inv_dptr_list,
+                                  otype);
+}
+
+std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> fused_multi_cast_transpose_alloc(
+    std::vector<at::Tensor> input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
+    std::vector<int> scale_indices, std::vector<int> amax_indices,
+    std::vector<int> scale_inv_indices, transformer_engine::DType otype) {
+  using namespace transformer_engine;
+
+  std::vector<at::Tensor> cast_output_list;
+  std::vector<at::Tensor> transposed_output_list;
+  std::vector<void*> scale_dptr_list, amax_dptr_list, scale_inv_dptr_list;
+  for (size_t i = 0; i < input_list.size(); ++i) {
+    auto input_i = input_list[i];
+    // construct cast output tensors
+    auto cast_output_i = allocateTorchTensor(input_i.size(0), input_i.size(1), DType::kByte);
+    cast_output_list.push_back(cast_output_i);
+    // construct transposed output tensors
+    auto transposed_output_i = allocateTorchTensor(input_i.size(1), input_i.size(0), DType::kByte);
+    transposed_output_list.push_back(transposed_output_i);
+    // construct amax/scale/scale_inv dptr lists
+    amax_dptr_list.push_back(getDataPtr(amax, amax_indices[i]));
+    scale_dptr_list.push_back(getDataPtr(scale, scale_indices[i]));
+    scale_inv_dptr_list.push_back(getDataPtr(scale_inv, scale_inv_indices[i]));
+  }
+
+  fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list,
+                                  transposed_output_list, amax_dptr_list, scale_inv_dptr_list,
+                                  otype);
+
+  return std::make_tuple(std::move(cast_output_list), std::move(transposed_output_list));
+}
+
 at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype) {
   using namespace transformer_engine;
...
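The `_alloc` variant moves the per-tensor output allocation from Python into C++: each `(M, N)` input gets a `uint8` `(M, N)` cast buffer and a `uint8` `(N, M)` transpose buffer, and the meta pointers are gathered by index from the shared `scale`/`amax`/`scale_inv` tensors. A Python rendering of just the allocation logic (a sketch for illustration, not the actual binding):

```python
from typing import List, Tuple

import torch

def alloc_cast_transpose_outputs(
    input_list: List[torch.Tensor],
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # Byte tensors sized (M, N) and (N, M) on each input's device. Empty
    # inputs still receive zero-sized buffers; the base function later
    # skips them via the nullptr check before launching the kernel.
    casts = [
        torch.empty(x.size(0), x.size(1), dtype=torch.uint8, device=x.device)
        for x in input_list
    ]
    transposes = [
        torch.empty(x.size(1), x.size(0), dtype=torch.uint8, device=x.device)
        for x in input_list
    ]
    return casts, transposes
```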
@@ -258,7 +258,8 @@ at::Tensor te_gemm_ts(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_te
   // Set an external SM Margin to all the GEMMs.
   // This comes in handy when DP is overlapped with GEMMs

-  const int sm_count = transformer_engine::cuda::sm_count();
+  const int device_id = at::cuda::current_device();
+  const int sm_count = transformer_engine::cuda::sm_count(device_id);
   int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

   if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
@@ -293,7 +294,8 @@ std::vector<at::Tensor> te_grouped_gemm_ts(
   // Set an external SM Margin to all the GEMMs.
   // This comes in handy when DP is overlapped with GEMMs

-  const int sm_count = transformer_engine::cuda::sm_count();
+  const int device_id = at::cuda::current_device();
+  const int sm_count = transformer_engine::cuda::sm_count(device_id);
   int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

   te_grouped_gemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse,
...
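Passing an explicit `device_id` lets the SM count be resolved from a per-device cache instead of querying the driver on every GEMM launch (the commit's "avoid cudaGetDriverEntryPoint" item); `num_math_sms` then reserves the `NVTE_EXT_MARGIN_SM` margin for kernels overlapped with the GEMM. A Python analogue of the caching, assuming `torch.cuda` is available:

```python
import functools

import torch

@functools.lru_cache(maxsize=None)
def sm_count(device_id: int) -> int:
    # Cache the multiprocessor count per device so repeated (grouped) GEMM
    # launches do not re-query the CUDA driver each time.
    return torch.cuda.get_device_properties(device_id).multi_processor_count
```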
@@ -34,7 +34,7 @@ from ..distributed import (
 from ..cpp_extensions import (
     cast_to_fp8,
     fp8_cast_transpose_bgrad_fused,
-    fp8_cast_transpose_fused,
+    fp8_multi_cast_transpose_fused,
     fp8_grouped_gemm,
     grouped_gemm,
 )
@@ -82,12 +82,12 @@ class _GroupedLinear(torch.autograd.Function):
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         is_grad_enabled: bool,
+        weights_fp8: List[Union[Float8Tensor, None]],
         *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
     ) -> torch.Tensor:
         num_gemms = len(m_splits)
         weights = weights_and_biases[:num_gemms]
-        weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]
-        biases = weights_and_biases[2 * num_gemms :]
+        biases = weights_and_biases[num_gemms:]

         # Make sure input dimensions are compatible
         in_features = weights[0].shape[-1]
@@ -113,15 +113,15 @@ class _GroupedLinear(torch.autograd.Function):
                 and not sequence_parallel
             ):
                 # FP8 input for forward, FP8 input transpose for backward wgrad
-                for i in range(num_gemms):
-                    mat, mat_t = fp8_cast_transpose_fused(
-                        inputmats_no_fp8[i],
-                        fp8_meta["scaling_fwd"],
-                        _GEMM_INPUT + i,
-                        fp8_dtype_forward,
-                    )
-                    inputmats.append(mat)
-                    inputmats_t.append(mat_t)
+                indices = list(range(_GEMM_INPUT, _GEMM_INPUT + num_gemms))
+                inputmats, inputmats_t = fp8_multi_cast_transpose_fused(
+                    inputmats_no_fp8,
+                    fp8_meta["scaling_fwd"],
+                    indices,  # scale_indices
+                    indices,  # amax_indices
+                    indices,  # scale_inv_indices
+                    fp8_dtype_forward,
+                )
             else:
                 # FP8 input for forward
                 inputmats = [
@@ -308,13 +308,15 @@ class _GroupedLinear(torch.autograd.Function):
             )
         else:
             if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
-                for i in range(ctx.num_gemms):
-                    grad_output_c[i], grad_output_t[i] = fp8_cast_transpose_fused(
-                        grad_output_mats[i],
-                        ctx.fp8_meta["scaling_bwd"],
-                        _GRAD_OUTPUT + i,
-                        fp8_dtype_backward,
-                    )
+                indices = list(range(_GRAD_OUTPUT, _GRAD_OUTPUT + ctx.num_gemms))
+                grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused(
+                    grad_output_mats,
+                    ctx.fp8_meta["scaling_bwd"],
+                    indices,  # scale_indices
+                    indices,  # amax_indices
+                    indices,  # scale_inv_indices
+                    fp8_dtype_backward,
+                )
             else:
                 for i in range(ctx.num_gemms):
                     grad_output_c[i] = cast_to_fp8(
@@ -334,7 +336,7 @@ class _GroupedLinear(torch.autograd.Function):
         if ctx.requires_dgrad:
             if ctx.fp8:
                 dgrad = torch.empty(
-                    (sum(ctx.m_splits), weights_fp8[i].size(1)),
+                    (sum(ctx.m_splits), weights_fp8[0].size(1)),
                     dtype=ctx.activation_dtype,
                     device=grad_output.device,
                 )
@@ -487,8 +489,8 @@ class _GroupedLinear(torch.autograd.Function):
             None,  # activation_dtype
             None,  # parallel_mode
             None,  # is_grad_enabled
+            None,  # weights_fp8
             *wgrad_list,
-            *([None] * ctx.num_gemms),  # weights_fp8
             *grad_biases,
         )
@@ -799,8 +801,8 @@ class GroupedLinear(TransformerEngineBaseModule):
             self.activation_dtype,
             self.parallel_mode,
             torch.is_grad_enabled(),
+            weight_tensors_fp8,
             *weight_tensors,
-            *weight_tensors_fp8,
             *bias_tensors,
         )
         out = linear_fn(*args)
...
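Promoting `weights_fp8` to its own positional argument simplifies the hot path: the trailing varargs now split into exactly two halves, and `backward` returns a single `None` for the list instead of `num_gemms` placeholders. A sketch of the new unpacking convention:

```python
from typing import Any, List, Tuple

def unpack_forward_args(
    m_splits: List[int],
    weights_fp8: List[Any],
    *weights_and_biases: Any,
) -> Tuple[tuple, tuple]:
    # num_gemms weights followed by num_gemms biases; weights_fp8 no longer
    # needs to be spliced out of the middle of the varargs.
    num_gemms = len(m_splits)
    weights = weights_and_biases[:num_gemms]
    biases = weights_and_biases[num_gemms:]
    return weights, biases
```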