Unverified Commit a4e95e86 authored by Xin Yao, committed by GitHub

[Common/PyTorch] Grouped GEMM via multi-stream cuBLAS (#853)



* GroupedGEMM via multi-stream cublas

* fix A/B is nullptr while D is not nullptr

* add fp8 grouped gemm

* register with TorchScript

* add the GroupedLinear layer

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Jiang Shao <jiangs@nvidia.com>
Co-authored-by: Qi Zhang <qizhang@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 85aeb903
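Before the diff itself, here is a minimal sketch of how the new GroupedLinear layer added by this PR might be used. It is modeled on the accuracy test introduced below; the shapes, dtype, and split sizes are illustrative assumptions only.

import torch
from transformer_engine.pytorch import GroupedLinear

# Three independent GEMMs, each with its own [out_features, in_features] weight and bias.
layer = GroupedLinear(3, 1024, 4096, bias=True, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(48, 1024, dtype=torch.bfloat16, device="cuda")
m_splits = [16, 16, 16]  # rows of x routed to each GEMM; must sum to x.shape[0]
y = layer(x, m_splits)   # -> [48, 4096]; row block i computes x_i @ weight_i.T + bias_i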
@@ -31,7 +31,8 @@ disable=too-many-locals,
global-variable-not-assigned,
redefined-argument-from-local,
line-too-long,
too-many-return-statements,
too-many-nested-blocks
[TYPECHECK]
ignored-modules=torch
...
@@ -24,6 +24,7 @@ from transformer_engine.pytorch import (
LayerNormLinear,
LayerNormMLP,
Linear,
GroupedLinear,
MultiheadAttention,
RMSNorm,
TransformerLayer,
@@ -31,7 +32,9 @@ from transformer_engine.pytorch import (
InferenceParams,
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
import transformer_engine_torch as tex
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -1211,6 +1214,99 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
reset_rng_states()
if fp8:
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
m = config.seq_len // 16
dist = torch.sort(torch.randint(0, m, (num_gemms - 1,))).values.tolist()
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * 16
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
with fp8_autocast(enabled=fp8):
if isinstance(block, GroupedLinear):
m_splits = m_splits * bs
out = block(inp_hidden_states, m_splits.tolist())
else:
out = torch.cat(
[
block[i](inp)
for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
]
)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_grouped_linear_accuracy(dtype, num_gemms, bs, model, fp8, fp8_model_params):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params):
grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=True,
params_dtype=dtype,
device="cuda",
).eval()
sequential_linear = torch.nn.ModuleList(
[
Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
params_dtype=dtype,
device="cuda",
).eval()
for _ in range(num_gemms)
]
)
# Share params
with torch.no_grad():
for i in range(num_gemms):
sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
outputs = _test_grouped_linear_accuracy(grouped_linear, num_gemms, bs, dtype, config, fp8)
outputs_ref = _test_grouped_linear_accuracy(
sequential_linear, num_gemms, bs, dtype, config, fp8
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
reset_rng_states()
@@ -1563,3 +1659,157 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
# Check if the fully generated output matches the one generated incrementally
assert_allclose(full_output, incremental_output, atol[dtype])
@pytest.mark.parametrize(
"shape",
[
(1, 127, 128, 512),
(8, 15, 128, 512),
(8, 1027, 128, 512),
(16, 10027, 128, 512),
],
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm(shape, dtype, layout, accumulate):
torch.manual_seed(0)
z, m, k, n = shape
dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
assert m_splits.sum() == m and len(m_splits) == z
m_splits = m_splits.tolist()
if layout == "TN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input
out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output
grad = False
elif layout == "NN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # grad_output
out = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # dgrad
grad = True
else: # layout == "NT"
A = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input
B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # grad_output
out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad
grad = True
out_ref = [o.clone() for o in out]
for i in range(z):
gemm(
A[i],
B[i],
dtype,
get_workspace(),
grad=grad,
accumulate=accumulate,
layout=layout,
out=out_ref[i],
)
grouped_gemm(
A,
B,
out,
dtype,
get_multi_stream_cublas_workspace(),
grad=grad,
accumulate=accumulate,
layout=layout,
)
# should be a bit-wise match
for o, o_ref in zip(out, out_ref):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize(
"shape",
[
(1, 128, 128, 512),
(8, 1024, 128, 512),
(16, 4096, 128, 512),
],
)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("accumulate", [False, True])
def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
z, m, k, n = shape
m_splits = m // z
dtype = torch.bfloat16
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input
out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output
out_ref = [o.clone() for o in out]
# FP8 should be robust enough to handle this fake scale
scale = 1 + torch.rand(z * 3, dtype=torch.float32, device="cuda")
scale_inv = 1 / scale
amax = torch.zeros(1024, z * 3, dtype=torch.float32, device="cuda")
A_fp8 = [
torch.ops.tex_ts.cast_to_fp8_ts(
A[i],
scale,
amax,
scale_inv,
i, # fp8 meta tensor index
tex.DType.kFloat8E4M3,
)
for i in range(z)
]
B_fp8 = [
torch.ops.tex_ts.cast_to_fp8_ts(
B[i],
scale,
amax,
scale_inv,
z + i, # fp8 meta tensor index
fp8_dtype,
)
for i in range(z)
]
fp8_grouped_gemm(
A_fp8,
scale_inv,
0, # A_offset
tex.DType.kFloat8E4M3,
B_fp8,
scale_inv,
z, # B_offset
fp8_dtype,
out,
dtype,
get_multi_stream_cublas_workspace(),
accumulate=accumulate,
)
# baseline
for i in range(z):
fp8_gemm(
A_fp8[i],
scale_inv,
i,
tex.DType.kFloat8E4M3,
B_fp8[i],
scale_inv,
z + i,
fp8_dtype,
dtype,
get_workspace(),
out=out_ref[i],
accumulate=accumulate,
)
# should be a bit-wise match
for o, o_ref in zip(out, out_ref):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@@ -11,6 +11,7 @@
#include <transformer_engine/transformer_engine.h>
#include <cstdint>
#include <mutex>
#include "../common.h"
#include "../util/logging.h"
@@ -208,6 +209,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
if (counter != nullptr) {
if (m_split == 0) m_split = 1;
@@ -275,6 +277,18 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}
static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
static cudaEvent_t cublas_event[num_streams];
// Warning: only call once per device!
static void init_streams_and_events() {
for (int i = 0; i < num_streams; i++) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event[i]));
}
}
} // namespace transformer_engine
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
@@ -363,3 +377,37 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
}
void nvte_multi_stream_cublas_gemm(std::vector<NVTETensor> A, std::vector<NVTETensor> B,
std::vector<NVTETensor> D, std::vector<NVTETensor> bias,
std::vector<NVTETensor> pre_gelu_out, bool transa, bool transb,
bool grad, std::vector<NVTETensor> workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
using namespace transformer_engine;
// Inits streams and events (once, globally)
std::call_once(init_flag, init_streams_and_events);
int num_stream_used = std::min(num_streams, static_cast<int>(A.size()));
// wait for current stream to finish
NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
}
for (size_t i = 0; i < A.size(); i++) {
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
compute_streams[i % num_streams]);
}
// record events on compute streams
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[s], compute_streams[s]));
}
// wait for all compute streams to finish
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
}
}
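For readers more familiar with the PyTorch stream API, the fork/join pattern implemented above corresponds roughly to the sketch below. It is illustrative only; stream priorities, the shared event pool, and the per-stream cuBLAS workspaces of the real C++ implementation are omitted, and the helper name is made up.

import torch

num_streams = 4
compute_streams = [torch.cuda.Stream() for _ in range(num_streams)]

def multi_stream_matmul(As, Bs, outs):
    main = torch.cuda.current_stream()
    used = min(num_streams, len(As))
    # Fork: side streams wait for everything already queued on the main stream.
    for s in compute_streams[:used]:
        s.wait_stream(main)
    # Round-robin the independent GEMMs across the side streams.
    for i, (a, b, out) in enumerate(zip(As, Bs, outs)):
        with torch.cuda.stream(compute_streams[i % num_streams]):
            torch.matmul(a, b, out=out)
    # Join: the main stream waits for every side stream before later work runs.
    for s in compute_streams[:used]:
        main.wait_stream(s)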
@@ -78,8 +78,46 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
bool use_split_accumulator, int math_sm_count, int m_split,
int n_split, bool gemm_producer, const NVTETensor counter,
cudaStream_t stream);
/*! \brief Compute multiple matrix multiplications, potentially fused with other operations,
* on multiple streams.
*
* Computes:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
* \param[in] A The list of A matrices.
* \param[in] B The list of B matrices.
* \param[in,out] D List of output matrices.
* \param[in] bias List of bias tensors.
* \param[in,out] pre_gelu_out List of output matrices before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of the
* gradient computation.
* \param[out] workspace List of workspace tensors.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM.
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] stream CUDA stream to wait on.
*/
void nvte_multi_stream_cublas_gemm(std::vector<NVTETensor> A, std::vector<NVTETensor> B,
std::vector<NVTETensor> D, std::vector<NVTETensor> bias,
std::vector<NVTETensor> pre_gelu_out, bool transa, bool transb,
bool grad, std::vector<NVTETensor> workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
/*! \namespace transformer_engine
*/
namespace transformer_engine {
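// Number of compute streams used by nvte_multi_stream_cublas_gemm; the PyTorch extension
// allocates one cuBLAS workspace per stream (see _NUM_MAX_CUBLAS_STREAMS in module/base.py).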
constexpr int num_streams = 4;
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GEMM_H_
@@ -37,8 +37,7 @@ from transformer_engine.pytorch.module import Linear
from transformer_engine.pytorch.module import LayerNormMLP
from transformer_engine.pytorch.module import LayerNorm
from transformer_engine.pytorch.module import RMSNorm
from transformer_engine.pytorch.module import initialize_ub
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.pytorch.module import destroy_ub
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import MultiheadAttention
...
@@ -3,14 +3,14 @@
# See LICENSE for license information.
"""Python interface for GEMM extensions"""
from typing import Optional, Tuple, Union, List
import torch
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import assert_dim_for_fp8_exec
__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"]
def fp8_gemm(
@@ -64,8 +64,6 @@ def fp8_gemm(
bias_dtype = TE_DType[bias_dtype]
out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
if A.nelement() == 0 or B.nelement() == 0:
return out, gelu_input
args = (
A,
@@ -216,8 +214,6 @@ def gemm(
grad_bias = empty_tensor
bias = bias if use_bias else empty_tensor
if A.nelement() == 0 or B.nelement() == 0:
return out, grad_bias, gelu_input
assert (
A.dtype == dtype and B.dtype == dtype
@@ -289,3 +285,158 @@ def gemm(
_ = fn(*args)
return out, grad_bias, gelu_input
def grouped_gemm(
A: List[torch.Tensor],
B: List[torch.Tensor],
out: List[torch.Tensor],
dtype: torch.dtype,
workspaces: List[torch.Tensor],
gelu: bool = False,
gelu_input: Optional[List[torch.Tensor]] = None,
grad: bool = False,
accumulate: bool = False,
layout: str = "TN",
bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
"""Non FP8 Grouped GEMM."""
assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
transa = layout[0] == "T"
transb = layout[1] == "T"
num_gemms = len(A)
empty_tensor = torch.Tensor()
empty_tensors = [torch.Tensor()] * num_gemms
if gelu and not grad:
gelu_input = [torch.empty_like(o, dtype=dtype) for o in out]
elif not gelu:
gelu_input = empty_tensors
if grad and use_bias:
grad_bias = [
torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
]
else:
grad_bias = empty_tensors
bias = bias if use_bias else empty_tensors
assert (
A[0].dtype == dtype and B[0].dtype == dtype
), f"Expected dtype={dtype}, but found A.dtype={A[0].dtype} and B.dtype={B[0].dtype}"
input_dtype = TE_DType[dtype]
output_dtype = TE_DType[out[0].dtype]
if use_bias:
bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
else:
bias_dtype = output_dtype
torch.ops.tex_ts.te_grouped_gemm_ts(
A,
empty_tensor,
0, # A_offset
input_dtype,
transa,
B,
empty_tensor,
0, # B_offset
input_dtype,
transb,
out,
0, # out_offset
empty_tensor, # out_scale
output_dtype,
empty_tensor, # out_amax
grad_bias if grad else bias,
bias_dtype,
gelu_input, # gelu_input
grad,
workspaces,
workspaces[0].shape[0],
accumulate,
False, # use_split_accumulator
)
return out, grad_bias, gelu_input
def fp8_grouped_gemm(
A: List[torch.Tensor],
A_scale_inv: torch.Tensor,
A_fp8_tensor_offset: int,
A_dtype: tex.DType,
B: List[torch.Tensor],
B_scale_inv: torch.Tensor,
B_fp8_tensor_offset: int,
B_dtype: tex.DType,
out: List[torch.Tensor],
out_dtype: torch.dtype,
workspaces: List[torch.Tensor],
out_offset: Optional[int] = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
gelu: bool = False,
accumulate: bool = False,
bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
"""
TN layout Grouped GEMM with fp8 inputs.
This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor.
scale: [ ...A_scale... | ...B_scale... | ...out_scale...]
scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...]
amax: [ ...A_amax... | ...B_amax... | ...out_amax...]
"""
num_gemms = len(A)
empty_tensor = torch.Tensor()
empty_tensors = [torch.Tensor()] * num_gemms
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_offset is not None
for a, b in zip(A, B):
assert_dim_for_fp8_exec(a)
assert_dim_for_fp8_exec(b)
assert A[0].dtype == torch.uint8
assert B[0].dtype == torch.uint8
# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
if gelu:
gelu_input = [torch.empty_like(o, dtype=bias_dtype) for o in out]
else:
gelu_input = empty_tensors
bias_dtype = TE_DType[bias_dtype]
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
torch.ops.tex_ts.te_grouped_gemm_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
return out, gelu_input
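To make the offset convention in the fp8_grouped_gemm docstring concrete, the sketch below shows one way a shared scale_inv buffer could be laid out and indexed. The names and the num_gemms value are assumptions for illustration; the FP8 grouped GEMM test earlier in this commit uses the same scheme with offsets 0 and z.

import torch

num_gemms = 4
# One contiguous buffer: [ A_0..A_3 | B_0..B_3 | out_0..out_3 ]
scale_inv = torch.ones(3 * num_gemms, dtype=torch.float32, device="cuda")
# fp8_grouped_gemm(A, scale_inv, 0, ..., B, scale_inv, num_gemms, ...) then reads
# scale_inv[0 + i] for A of GEMM i and scale_inv[num_gemms + i] for B of GEMM i.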
@@ -125,6 +125,16 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
bool gemm_producer, at::Tensor counter);
void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax,
std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);
/***************************************************************************************************
 * Transpose
 **************************************************************************************************/
...
@@ -4,6 +4,7 @@
 * See LICENSE for license information.
 ************************************************************************/
#include "common/util/cuda_runtime.h"
#include "extensions.h"
void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
@@ -14,6 +15,12 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
if (A.data_ptr() == nullptr || B.data_ptr() == nullptr) {
if (D.data_ptr() != nullptr && !accumulate) D.zero_();
if (bias.data_ptr() != nullptr) bias.zero_();
if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_();
return;
}
auto te_A = makeTransformerEngineTensor(
A.data_ptr(), {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))}, A_type,
nullptr, nullptr, A_scale_inverse.data_ptr());
@@ -77,3 +84,58 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream());
}
void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax,
std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype, void* amax_dptr,
void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
tensor_wrappers.emplace_back(
makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
for (size_t i = 0; i < A.size(); i++) {
if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) {
if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_();
if (bias[i].data_ptr() != nullptr) bias[i].zero_();
if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_();
continue;
}
te_A.emplace_back(make_tensor(
A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i)));
te_B.emplace_back(make_tensor(
B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
te_D.emplace_back(make_tensor(
D[i].data_ptr(), {static_cast<size_t>(D[i].size(0)), static_cast<size_t>(D[i].size(1))},
D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
bias_type, nullptr, nullptr, nullptr));
const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
static_cast<size_t>(pre_gelu_out[i].size(1))};
te_pre_gelu_out.emplace_back(make_tensor(
pre_gelu_out[i].data_ptr(), gelu_shape,
GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
te_workspace.emplace_back(make_tensor(workspace[i % num_streams].data_ptr(), {workspaceSize},
DType::kByte, nullptr, nullptr, nullptr));
}
// For now, the multi-stream cuBLAS backend is the only available backend.
nvte_multi_stream_cublas_gemm(te_A, te_B, te_D, te_bias, te_pre_gelu_out, transa, transb, grad,
te_workspace, accumulate, use_split_accumulator, math_sm_count,
at::cuda::getCurrentCUDAStream());
}
@@ -89,6 +89,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::call_guard<py::gil_scoped_release>());
m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard<py::gil_scoped_release>());
m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think
m.def("te_grouped_gemm", &te_grouped_gemm, "Grouped GEMM");
m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed QKV",
py::call_guard<py::gil_scoped_release>());
...
@@ -271,6 +271,38 @@ at::Tensor te_gemm_ts(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_te
return D;
}
std::vector<at::Tensor> te_grouped_gemm_ts(
std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset, at::Tensor D_scale,
int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias, int64_t bias_type,
std::vector<at::Tensor> pre_gelu_out, int64_t grad, std::vector<at::Tensor> workspace,
int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator) {
// cast inputs to types accepted by te_gemm
transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
bool transa_arg = static_cast<bool>(transa);
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs
const int sm_count = transformer_engine::cuda::sm_count();
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
te_grouped_gemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse,
B_offset, B_type_arg, transb_arg, D, D_offset, D_scale, D_type_arg, D_amax, bias,
bias_type_arg, pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
accumulate_arg, use_split_accumulator_arg, num_math_sms);
return D;
}
at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, double eps, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor,
@@ -336,6 +368,7 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("qgelu_ts", &qgelu_ts);
m.def("srelu_ts", &srelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);
...
@@ -5,6 +5,7 @@
"""Module level PyTorch APIs"""
from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .grouped_linear import GroupedLinear
from .layernorm_mlp import LayerNormMLP
from .layernorm import LayerNorm
from .rmsnorm import RMSNorm
...
@@ -41,9 +41,11 @@ __all__ = ["initialize_ub", "destroy_ub"]
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = []
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3
_NUM_MAX_CUBLAS_STREAMS = 4
layers_atomic_ring_exchange = []
@@ -64,6 +66,17 @@ def get_workspace() -> torch.Tensor:
return _cublas_workspace
def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
"""Returns workspace for multi-stream cublas."""
global _multi_stream_cublas_workspace
if not _multi_stream_cublas_workspace:
for _ in range(_NUM_MAX_CUBLAS_STREAMS):
_multi_stream_cublas_workspace.append(
torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
)
return _multi_stream_cublas_workspace
def initialize_ub(
shape: list,
tp_group: dist_group_type,
...
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""GroupedLinear API"""
import os
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
import torch
import transformer_engine_torch as tex
from .base import (
get_multi_stream_cublas_workspace,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
divide,
cast_if_needed,
assert_dim_for_fp8_exec,
clear_tensor_data,
init_method_constant,
requires_grad,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
get_distributed_world_size,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
)
from ..cpp_extensions import (
cast_to_fp8,
fp8_cast_transpose_bgrad_fused,
fp8_cast_transpose_fused,
fp8_grouped_gemm,
grouped_gemm,
)
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
__all__ = ["GroupedLinear"]
"""
The offset for fp8_meta_index.
_GEMM_INPUT = 0
_GEMM_WEIGHT = num_gemms
_GEMM_OUTPUT = 2 * num_gemms
Must be properly set in GroupedLinear's initialization.
"""
_GEMM_INPUT = 0
_GEMM_WEIGHT = 0
_GEMM_OUTPUT = 0
_GRAD_OUTPUT = 0
def _pad_tensor(inp: torch.Tensor):
if inp.shape[0] % 16 == 0:
return inp
pad_len = (inp.shape[0] + 15) // 16 * 16 - inp.shape[0]
pad_tensor = torch.zeros(pad_len, inp.shape[1], dtype=inp.dtype, device=inp.device)
return torch.cat((inp, pad_tensor), dim=0)
class _GroupedLinear(torch.autograd.Function):
"""GroupedLinear semi-top level module
Calls custom cuda extensions.
"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
m_splits: List[int],
use_bias: bool,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
) -> torch.Tensor:
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]
biases = weights_and_biases[2 * num_gemms :]
# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmats = torch.split(inp.view(-1, in_features), m_splits)
if fp8:
inputmats = [_pad_tensor(mat) for mat in inputmats]
for i in range(num_gemms):
assert_dim_for_fp8_exec(inputmats[i])
assert_dim_for_fp8_exec(weights[i])
# Cast input to expected dtype
inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
inputmats = []
inputmats_t = []
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if (
not fp8_meta["recipe"].override_linear_precision.wgrad
and is_grad_enabled
and weights[0].requires_grad
and not sequence_parallel
):
# FP8 input for forward, FP8 input transpose for backward wgrad
for i in range(num_gemms):
mat, mat_t = fp8_cast_transpose_fused(
inputmats_no_fp8[i],
fp8_meta["scaling_fwd"],
_GEMM_INPUT + i,
fp8_dtype_forward,
)
inputmats.append(mat)
inputmats_t.append(mat_t)
else:
# FP8 input for forward
inputmats = [
cast_to_fp8(
inputmats_no_fp8[i],
fp8_meta["scaling_fwd"],
_GEMM_INPUT + i,
fp8_dtype_forward,
)
for i in range(num_gemms)
]
else:
inputmats = inputmats_no_fp8
if fp8:
if _NVTE_DEBUG:
print("[GroupedLinear]: using FP8 forward")
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
# Use FP8 weights
if weights_fp8[0] is None:
weights_fp8 = weights
assert all(isinstance(w, Float8Tensor) for w in weights_fp8)
out_list = [
torch.empty(
[inputmats[i].size(0), weights_fp8[0].size(0)],
dtype=activation_dtype,
device=inputmats[i].device,
)
for i in range(num_gemms)
]
_ = fp8_grouped_gemm(
[w._data for w in weights_fp8],
fp8_meta["scaling_fwd"].scale_inv,
_GEMM_WEIGHT,
fp8_dtype_forward,
inputmats,
fp8_meta["scaling_fwd"].scale_inv,
_GEMM_INPUT,
fp8_dtype_forward,
out_list,
activation_dtype,
get_multi_stream_cublas_workspace(),
bias=biases,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
)
# unpad the output
out = torch.cat([o[: m_splits[i]] for i, o in enumerate(out_list)], dim=0)
else:
if _NVTE_DEBUG:
print("[GroupedLinear]: using non-FP8 forward")
# Cast for native AMP
weights = [cast_if_needed(w, activation_dtype) for w in weights]
biases = (
[cast_if_needed(bias, activation_dtype) for bias in biases] if use_bias else biases
)
if fp8_calibration:
for i in range(num_gemms):
# amax of input
amin, amax = inputmats[i].aminmax()
fp8_meta["scaling_fwd"].amax_history[0][_GEMM_INPUT + i] = torch.max(
-amin, amax
).float()
# amax of weight
amin, amax = weights[i].aminmax()
fp8_meta["scaling_fwd"].amax_history[0][_GEMM_WEIGHT + i] = torch.max(
-amin, amax
).float()
out = torch.empty(
[sum(m_splits), weights[0].size(0)],
dtype=activation_dtype,
device=inputmats[0].device,
)
_ = grouped_gemm(
weights,
inputmats,
torch.split(out, m_splits),
activation_dtype,
get_multi_stream_cublas_workspace(),
bias=biases,
use_bias=use_bias,
)
if is_grad_enabled:
saved_inputmats = [None] * num_gemms
saved_inputmats_t = [None] * num_gemms
if weights[0].requires_grad:
if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad:
if not inputmats_t:
saved_inputmats = inputmats
else:
saved_inputmats_t = inputmats_t
if cpu_offloading:
for t in saved_inputmats_t:
t.activation_offloading = True
else:
saved_inputmats = inputmats_no_fp8
if cpu_offloading:
if fuse_wgrad_accumulation:
for w in weights:
w.main_grad.weight_offloading = True
if fp8:
for w in weights_fp8:
if w is not None:
w.weight_offloading = True
for w in weights:
w.weight_offloading = True
for t in saved_inputmats:
if t is not None:
t.activation_offloading = True
ctx.save_for_backward(
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
*saved_inputmats,
*saved_inputmats_t,
*weights,
*weights_fp8,
*[
w.main_grad if cpu_offloading and fuse_wgrad_accumulation else None
for w in weights
],
)
ctx.m_splits = m_splits
ctx.num_gemms = num_gemms
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
ctx.reduce_and_update_bwd_fp8_tensors = (
ctx.reduce_and_update_bwd_fp8_tensors
or FP8GlobalStateManager.is_first_fp8_module()
)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_GroupedLinear_backward"):
(
fwd_scale_inverses,
*saved_tensors,
) = ctx.saved_tensors
inputmats = saved_tensors[: ctx.num_gemms]
inputmats_t = saved_tensors[ctx.num_gemms : 2 * ctx.num_gemms]
weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms]
weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms]
main_grads = saved_tensors[4 * ctx.num_gemms :]
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], False)
w.main_grad = main_grads[i]
weights[i] = w
global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
# preprocess grad_output
grad_output = grad_output.contiguous()
grad_output_mats = torch.split(
grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits
)
grad_output_c = [None] * ctx.num_gemms
grad_output_t = [None] * ctx.num_gemms
grad_biases = [None] * ctx.num_gemms
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
grad_output_mats = [_pad_tensor(mat) for mat in grad_output_mats]
if ctx.use_bias:
for i in range(ctx.num_gemms):
grad_biases[i], grad_output_c[i], grad_output_t[i] = (
fp8_cast_transpose_bgrad_fused(
grad_output_mats[i],
ctx.fp8_meta["scaling_bwd"],
_GRAD_OUTPUT + i,
fp8_dtype_backward,
)
)
else:
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
for i in range(ctx.num_gemms):
grad_output_c[i], grad_output_t[i] = fp8_cast_transpose_fused(
grad_output_mats[i],
ctx.fp8_meta["scaling_bwd"],
_GRAD_OUTPUT + i,
fp8_dtype_backward,
)
else:
for i in range(ctx.num_gemms):
grad_output_c[i] = cast_to_fp8(
grad_output_mats[i],
ctx.fp8_meta["scaling_bwd"],
_GRAD_OUTPUT + i,
fp8_dtype_backward,
)
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.requires_dgrad:
if ctx.fp8:
if _NVTE_DEBUG:
print("[GroupedLinear]: using FP8 backward")
dgrad_list = [
torch.empty(
(grad_output_c[i].size(0), weights_fp8[i].size(1)),
dtype=ctx.activation_dtype,
device=grad_output.device,
)
for i in range(ctx.num_gemms)
]
fp8_grouped_gemm(
[w.transpose_2d() for w in weights_fp8],
torch.cat(
[w._scale_inv for w in weights_fp8]
), # avoiding torch.cat requires another interface
0, # weight offset is 0 for the newly created _scale_inv
weights_fp8[0]._fp8_dtype,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
0,
fp8_dtype_backward,
dgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
# unpad the output
dgrad = torch.cat(
[d[: ctx.m_splits[i]] for i, d in enumerate(dgrad_list)], dim=0
)
else:
if _NVTE_DEBUG:
print("[GroupedLinear]: using non-FP8 backward")
dgrad = torch.empty(
(sum(ctx.m_splits), weights[0].size(1)),
dtype=ctx.activation_dtype,
device=grad_output.device,
)
grouped_gemm(
weights,
grad_output_mats,
torch.split(dgrad, ctx.m_splits),
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
layout="NN",
grad=True,
)
if weights[0].requires_grad:
if ctx.fuse_wgrad_accumulation:
wgrad_list = [w.main_grad for w in weights]
else:
wgrad_list = [
torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
for w in weights
]
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if inputmats_t[0] is None:
for i in range(ctx.num_gemms):
if isinstance(inputmats[i], Float8Tensor):
inputmats_t[i] = inputmats[i].transpose_2d()
else:
inputmats_t[i] = tex.fp8_transpose(
inputmats[i], fp8_dtype_backward
)
fp8_grouped_gemm(
[
inp._data if isinstance(inp, Float8Tensor) else inp
for inp in inputmats_t
],
fwd_scale_inverses,
_GEMM_INPUT,
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
_GRAD_OUTPUT,
fp8_dtype_backward,
wgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
use_split_accumulator=_2X_ACC_WGRAD,
)
else:
grouped_gemm(
inputmats,
grad_output_mats,
wgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
)
else:
# WGRAD
_, grad_biases, _ = grouped_gemm(
inputmats,
grad_output_mats,
wgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
)
# Deallocate input tensor
clear_tensor_data(*inputmats)
clear_tensor_data(*inputmats_t)
if not ctx.use_bias:
grad_biases = [None] * ctx.num_gemms
def handle_custom_ddp_from_mcore(w, wgrad):
if w.requires_grad:
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False):
wgrad = torch.zeros(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
wgrad = torch.empty(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
else:
wgrad = None
return wgrad
wgrad_list = [
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
]
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
None, # m_splits
None, # use_bias
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # fp8_meta
None, # fuse_wgrad_accumulation
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # parallel_mode
None, # is_grad_enabled
*wgrad_list,
*([None] * ctx.num_gemms), # weights_fp8
*grad_biases,
)
class GroupedLinear(TransformerEngineBaseModule):
"""Applies linear transformations to the incoming data list
:math:`y_i = x_iA_i^T + b_i` in a grouped way.
Parameters
----------
num_gemms : int
number of GEMMs to be performed simultaneously.
in_features : int
size of each input sample.
out_features : int
size of each output sample.
bias : bool, default = `True`
if set to `False`, the layer will not learn an additive bias.
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
get_rng_state_tracker : Callable, default = `None`
used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default = `None`
the param passed to get_rng_state_tracker to get the specific rng tracker.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'Column', 'Row'}, default = `None`
used to decide whether this GroupedLinear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.get_default_dtype()`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
"""
def __init__(
self,
num_gemms: int,
in_features: int,
out_features: int,
sequence_parallel: bool = False,
fuse_wgrad_accumulation: bool = False,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
rng_tracker_name: Optional[str] = None,
init_method: Optional[Callable] = None,
bias: bool = True,
return_bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
parallel_mode: Optional[str] = None,
device: Union[torch.device, str] = "cuda",
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
ub_name: Optional[str] = None,
) -> None:
super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.num_gemms = num_gemms
self.in_features = in_features
self.out_features = out_features
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.ub_overlap_rs = ub_overlap_rs
self.ub_overlap_ag = ub_overlap_ag
self.ub_name = ub_name
assert (
not ub_overlap_rs and not ub_overlap_ag
), "GroupedLinear doesn't support Userbuffer overlap."
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
_GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, num_gemms, 2 * num_gemms
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
self.set_tensor_parallel_group(tp_group)
else:
self.tp_size = get_distributed_world_size(tp_group)
self.set_tensor_parallel_group(tp_group)
self.set_nccl_overlap_warning_if_tp()
self.parallel_mode = parallel_mode
assert (
self.parallel_mode in GemmParallelModes
), f"parallel_mode {parallel_mode} not supported"
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
for i in range(self.num_gemms):
# Construct weight parameter
self.register_parameter(
f"weight{i}",
torch.nn.Parameter(
torch.empty(
self.out_features,
self.in_features,
device=device,
dtype=params_dtype,
),
),
init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=_GEMM_WEIGHT + i,
)
# Construct bias parameters if needed
if self.use_bias:
self.register_parameter(
f"bias{i}",
torch.nn.Parameter(
torch.empty(
self.out_features,
device=device,
dtype=params_dtype,
),
),
init_fn=init_method_constant(0.0),
)
else:
bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, f"bias{i}", bias)
if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=self.num_gemms)
self.reset_parameters(defer_init=(device == "meta"))
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.apply_bias:
self.gemm_bias_unfused_add = True
else:
self.gemm_bias_unfused_add = False
def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
if not defer_init:
# Set parallelism attributes for linear weights
for i in range(self.num_gemms):
set_tensor_model_parallel_attributes(
tensor=getattr(self, f"weight{i}"),
is_parallel=True,
dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
# Set parallelism attributes for linear biases
if self.use_bias:
for i in range(self.num_gemms):
if self.parallel_mode == "row":
setattr(
getattr(self, f"bias{i}"),
"sequence_parallel",
self.sequence_parallel,
)
elif self.parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1)
@no_torch_dynamo()
def forward(
self,
inp: torch.Tensor,
m_splits: List[int],
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply the linear transformation to the input.
Parameters
----------
inp : torch.Tensor
Input tensor.
m_splits : List[int]
List of integers representing the split of the input tensor.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
"""
assert not isinstance(
inp, Float8Tensor
), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp:
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8:
weight_tensors = [
w.from_float8() if isinstance(w, Float8Tensor) else w for w in weight_tensors
]
weight_tensors_fp8 = [None] * self.num_gemms
if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
for i in range(self.num_gemms):
if isinstance(weight_tensors[i], Float8Tensor):
# Fill transpose cache in FP8 tensor if needed
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch or skip_fp8_weight_update is not None
)
if update_transpose_cache:
weight_tensors[i].transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
)
else:
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
weight_tensors_fp8[i] = self.get_fp8_workspace(
tensor=weight_tensors[i],
fp8_meta_forward=True,
fp8_meta_index=_GEMM_WEIGHT + i,
cache_name=(None if is_first_microbatch is None else f"weight{i}"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
)
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled():
linear_fn = _GroupedLinear.apply
args = []
else:
linear_fn = _GroupedLinear.forward
args = [None]
args += (
inp,
m_splits,
self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
*weight_tensors,
*weight_tensors_fp8,
*bias_tensors,
)
out = linear_fn(*args)
if self.gemm_bias_unfused_add:
out = [o + cast_if_needed(b, self.activation_dtype) for o, b in zip(out, bias_tensors)]
if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
return out
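Finally, a hedged end-to-end sketch of the FP8 path through the new layer. An FP8-capable GPU (e.g. H100) is assumed, the shapes are illustrative, and the accuracy tests earlier in this commit exercise the same code path.

import torch
from transformer_engine.pytorch import GroupedLinear, fp8_autocast

layer = GroupedLinear(2, 256, 1024, bias=True, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(64, 256, dtype=torch.bfloat16, device="cuda", requires_grad=True)
with fp8_autocast(enabled=True):
    y = layer(x, m_splits=[32, 32])  # forward GEMMs run in FP8 on the multi-stream cuBLAS path
y.sum().backward()                   # dgrad/wgrad also go through the grouped GEMM extensions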