Unverified commit a4e95e86, authored by Xin Yao, committed by GitHub

[Common/PyTorch] Grouped GEMM via multi-stream cuBLAS (#853)



* GroupedGEMM via multi-stream cublas

* fix the case where A/B is nullptr while D is not nullptr

* add fp8 grouped gemm

* register with TorchScript

* add the GroupedLinear layer

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Jiang Shao <jiangs@nvidia.com>
Co-authored-by: Qi Zhang <qizhang@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 85aeb903
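
For context on the GroupedLinear item above, here is a minimal usage sketch, assuming only the constructor and forward signature exercised by the tests in this change (the 256/1024 sizes and the [16, 16, 16] split are illustrative, not prescriptive):

import torch
from transformer_engine.pytorch import GroupedLinear

# Three GEMMs dispatched through the new multi-stream cuBLAS path.
layer = GroupedLinear(3, 256, 1024, bias=True, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(48, 1, 256, dtype=torch.bfloat16, device="cuda", requires_grad=True)
# m_splits partitions the tokens across the three GEMMs and must sum to 48 (= 48 * 1).
out = layer(x, [16, 16, 16])
out.sum().backward()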
......@@ -31,7 +31,8 @@ disable=too-many-locals,
global-variable-not-assigned,
redefined-argument-from-local,
line-too-long,
too-many-return-statements
too-many-return-statements,
too-many-nested-blocks
[TYPECHECK]
ignored-modules=torch
......@@ -24,6 +24,7 @@ from transformer_engine.pytorch import (
LayerNormLinear,
LayerNormMLP,
Linear,
GroupedLinear,
MultiheadAttention,
RMSNorm,
TransformerLayer,
......@@ -31,7 +32,9 @@ from transformer_engine.pytorch import (
InferenceParams,
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
import transformer_engine_torch as tex
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
......@@ -1211,6 +1214,99 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
reset_rng_states()
if fp8:
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
m = config.seq_len // 16
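# Draw num_gemms - 1 random cut points in [0, m) and difference them so the splits are
# non-negative, sum to m, and (after the *16 below) sum to config.seq_len.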
dist = torch.sort(torch.randint(0, m, (num_gemms - 1,))).values.tolist()
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * 16
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
with fp8_autocast(enabled=fp8):
if isinstance(block, GroupedLinear):
m_splits = m_splits * bs
out = block(inp_hidden_states, m_splits.tolist())
else:
out = torch.cat(
[
block[i](inp)
for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
]
)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_grouped_linear_accuracy(dtype, num_gemms, bs, model, fp8, fp8_model_params):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params):
grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=True,
params_dtype=dtype,
device="cuda",
).eval()
sequential_linear = torch.nn.ModuleList(
[
Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
params_dtype=dtype,
device="cuda",
).eval()
for _ in range(num_gemms)
]
)
# Share params
with torch.no_grad():
for i in range(num_gemms):
sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
outputs = _test_grouped_linear_accuracy(grouped_linear, num_gemms, bs, dtype, config, fp8)
outputs_ref = _test_grouped_linear_accuracy(
sequential_linear, num_gemms, bs, dtype, config, fp8
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
reset_rng_states()
......@@ -1563,3 +1659,157 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
# Check if the fully generated output matches the one generated incrementally
assert_allclose(full_output, incremental_output, atol[dtype])
@pytest.mark.parametrize(
"shape",
[
(1, 127, 128, 512),
(8, 15, 128, 512),
(8, 1027, 128, 512),
(16, 10027, 128, 512),
],
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm(shape, dtype, layout, accumulate):
torch.manual_seed(0)
z, m, k, n = shape
dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
assert m_splits.sum() == m and len(m_splits) == z
m_splits = m_splits.tolist()
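# The three layouts map to the three GEMMs of a linear layer: TN = forward, NN = dgrad,
# NT = wgrad (see the per-tensor labels below).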
if layout == "TN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input
out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output
grad = False
elif layout == "NN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # grad_output
out = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # dgrad
grad = True
else: # layout == "NT"
A = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input
B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # grad_output
out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad
grad = True
out_ref = [o.clone() for o in out]
for i in range(z):
gemm(
A[i],
B[i],
dtype,
get_workspace(),
grad=grad,
accumulate=accumulate,
layout=layout,
out=out_ref[i],
)
grouped_gemm(
A,
B,
out,
dtype,
get_multi_stream_cublas_workspace(),
grad=grad,
accumulate=accumulate,
layout=layout,
)
# Should be a bit-wise match
for o, o_ref in zip(out, out_ref):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize(
"shape",
[
(1, 128, 128, 512),
(8, 1024, 128, 512),
(16, 4096, 128, 512),
],
)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("accumulate", [False, True])
def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
z, m, k, n = shape
m_splits = m // z
dtype = torch.bfloat16
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input
out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output
out_ref = [o.clone() for o in out]
# FP8 should be robust to this fake scale
scale = 1 + torch.rand(z * 3, dtype=torch.float32, device="cuda")
scale_inv = 1 / scale
amax = torch.zeros(1024, z * 3, dtype=torch.float32, device="cuda")
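# The meta tensors hold z slots for A, z for B, and z for the (unused here) outputs,
# matching the contiguous layout that fp8_grouped_gemm assumes.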
A_fp8 = [
torch.ops.tex_ts.cast_to_fp8_ts(
A[i],
scale,
amax,
scale_inv,
i, # fp8 meta tensor index
tex.DType.kFloat8E4M3,
)
for i in range(z)
]
B_fp8 = [
torch.ops.tex_ts.cast_to_fp8_ts(
B[i],
scale,
amax,
scale_inv,
z + i, # fp8 meta tensor index
fp8_dtype,
)
for i in range(z)
]
fp8_grouped_gemm(
A_fp8,
scale_inv,
0, # A_offset
tex.DType.kFloat8E4M3,
B_fp8,
scale_inv,
z, # B_offset
fp8_dtype,
out,
dtype,
get_multi_stream_cublas_workspace(),
accumulate=accumulate,
)
# baseline
for i in range(z):
fp8_gemm(
A_fp8[i],
scale_inv,
i,
tex.DType.kFloat8E4M3,
B_fp8[i],
scale_inv,
z + i,
fp8_dtype,
dtype,
get_workspace(),
out=out_ref[i],
accumulate=accumulate,
)
# Should be a bit-wise match
for o, o_ref in zip(out, out_ref):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
......@@ -11,6 +11,7 @@
#include <transformer_engine/transformer_engine.h>
#include <cstdint>
#include <mutex>
#include "../common.h"
#include "../util/logging.h"
......@@ -208,6 +209,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
if (counter != nullptr) {
if (m_split == 0) m_split = 1;
......@@ -275,6 +277,18 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}
static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
static cudaEvent_t cublas_event[num_streams];
// Warning: only call once per device!
static void init_streams_and_events() {
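// Create non-blocking streams at priority -1 (higher than the default 0), plus one event per stream.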
for (int i = 0; i < num_streams; i++) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event[i]));
}
}
} // namespace transformer_engine
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
......@@ -363,3 +377,37 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
}
void nvte_multi_stream_cublas_gemm(std::vector<NVTETensor> A, std::vector<NVTETensor> B,
std::vector<NVTETensor> D, std::vector<NVTETensor> bias,
std::vector<NVTETensor> pre_gelu_out, bool transa, bool transb,
bool grad, std::vector<NVTETensor> workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
using namespace transformer_engine;
// Inits streams and events (once, globally)
std::call_once(init_flag, init_streams_and_events);
int num_stream_used = std::min(num_streams, static_cast<int>(A.size()));
// make the compute streams wait for work already enqueued on the current stream
NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
}
for (size_t i = 0; i < A.size(); i++) {
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
compute_streams[i % num_streams]);
}
// record events on compute streams
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[s], compute_streams[s]));
}
// make the current stream wait for all compute streams
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
}
}
......@@ -78,8 +78,46 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
bool use_split_accumulator, int math_sm_count, int m_split,
int n_split, bool gemm_producer, const NVTETensor counter,
cudaStream_t stream);
/*! \brief Compute multiple matrix multiplications, potentially fused with other operations,
* on multiple CUDA streams.
*
* Computes:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
* \param[in] A The list of A matrices.
* \param[in] B The list of B matrices.
* \param[in,out] D List of output matrices.
* \param[in] bias List of bias tensors.
* \param[in,out] pre_gelu_out List of output matrices before the GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of the
* gradient computation.
* \param[out] workspace List of workspace tensors.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM.
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] stream CUDA stream to wait on.
*/
void nvte_multi_stream_cublas_gemm(std::vector<NVTETensor> A, std::vector<NVTETensor> B,
std::vector<NVTETensor> D, std::vector<NVTETensor> bias,
std::vector<NVTETensor> pre_gelu_out, bool transa, bool transb,
bool grad, std::vector<NVTETensor> workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
/*! \namespace transformer_engine
*/
namespace transformer_engine {
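/*! \brief Number of CUDA streams that nvte_multi_stream_cublas_gemm round-robins its GEMMs across. */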
constexpr int num_streams = 4;
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GEMM_H_
......@@ -37,8 +37,7 @@ from transformer_engine.pytorch.module import Linear
from transformer_engine.pytorch.module import LayerNormMLP
from transformer_engine.pytorch.module import LayerNorm
from transformer_engine.pytorch.module import RMSNorm
from transformer_engine.pytorch.module import initialize_ub
from transformer_engine.pytorch.module import destroy_ub
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import MultiheadAttention
......@@ -3,14 +3,14 @@
# See LICENSE for license information.
"""Python interface for GEMM extensions"""
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, List
import torch
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import assert_dim_for_fp8_exec
__all__ = ["gemm", "fp8_gemm"]
__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"]
def fp8_gemm(
......@@ -64,8 +64,6 @@ def fp8_gemm(
bias_dtype = TE_DType[bias_dtype]
out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
if A.nelement() == 0 or B.nelement() == 0:
return out, gelu_input
args = (
A,
......@@ -216,8 +214,6 @@ def gemm(
grad_bias = empty_tensor
bias = bias if use_bias else empty_tensor
if A.nelement() == 0 or B.nelement() == 0:
return out, grad_bias, gelu_input
assert (
A.dtype == dtype and B.dtype == dtype
......@@ -289,3 +285,158 @@ def gemm(
_ = fn(*args)
return out, grad_bias, gelu_input
def grouped_gemm(
A: List[torch.Tensor],
B: List[torch.Tensor],
out: List[torch.Tensor],
dtype: torch.dtype,
workspaces: List[torch.Tensor],
gelu: bool = False,
gelu_input: Optional[List[torch.Tensor]] = None,
grad: bool = False,
accumulate: bool = False,
layout: str = "TN",
bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
"""Non FP8 Grouped GEMM."""
assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
transa = layout[0] == "T"
transb = layout[1] == "T"
num_gemms = len(A)
empty_tensor = torch.Tensor()
empty_tensors = [torch.Tensor()] * num_gemms
if gelu and not grad:
gelu_input = [torch.empty_like(o, dtype=dtype) for o in out]
elif not gelu:
gelu_input = empty_tensors
if grad and use_bias:
grad_bias = [
torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
]
else:
grad_bias = empty_tensors
bias = bias if use_bias else empty_tensors
assert (
A[0].dtype == dtype and B[0].dtype == dtype
), f"Expected dtype={dtype}, but found A.dtype={A[0].dtype} and B.dtype={B[0].dtype}"
input_dtype = TE_DType[dtype]
output_dtype = TE_DType[out[0].dtype]
if use_bias:
bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
else:
bias_dtype = output_dtype
torch.ops.tex_ts.te_grouped_gemm_ts(
A,
empty_tensor,
0, # A_offset
input_dtype,
transa,
B,
empty_tensor,
0, # B_offset
input_dtype,
transb,
out,
0, # out_offset
empty_tensor, # out_scale
output_dtype,
empty_tensor, # out_amax
grad_bias if grad else bias,
bias_dtype,
gelu_input, # gelu_input
grad,
workspaces,
workspaces[0].shape[0],
accumulate,
False, # use_split_accumulator
)
return out, grad_bias, gelu_input
def fp8_grouped_gemm(
A: List[torch.Tensor],
A_scale_inv: torch.Tensor,
A_fp8_tensor_offset: int,
A_dtype: tex.DType,
B: List[torch.Tensor],
B_scale_inv: torch.Tensor,
B_fp8_tensor_offset: int,
B_dtype: tex.DType,
out: List[torch.Tensor],
out_dtype: torch.dtype,
workspaces: List[torch.Tensor],
out_offset: Optional[int] = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
gelu: bool = False,
accumulate: bool = False,
bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
"""
TN layout grouped GEMM with FP8 inputs.
This method assumes that the scale/scale_inv/amax of A/B/out are contiguous in the meta tensor.
scale: [ ...A_scale... | ...B_scale... | ...out_scale...]
scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...]
amax: [ ...A_amax... | ...B_amax... | ...out_amax...]
"""
num_gemms = len(A)
empty_tensor = torch.Tensor()
empty_tensors = [torch.Tensor()] * num_gemms
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_offset is not None
for a, b in zip(A, B):
assert_dim_for_fp8_exec(a)
assert_dim_for_fp8_exec(b)
assert A[0].dtype == torch.uint8
assert B[0].dtype == torch.uint8
# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
if gelu:
gelu_input = [torch.empty_like(o, dtype=bias_dtype) for o in out]
else:
gelu_input = empty_tensors
bias_dtype = TE_DType[bias_dtype]
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
torch.ops.tex_ts.te_grouped_gemm_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
return out, gelu_input
......@@ -125,6 +125,16 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
bool gemm_producer, at::Tensor counter);
void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax,
std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);
/***************************************************************************************************
* Transpose
**************************************************************************************************/
......@@ -4,6 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "common/util/cuda_runtime.h"
#include "extensions.h"
void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
......@@ -14,6 +15,12 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
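// A null A or B data pointer means there is nothing to compute: zero the outputs and return
// instead of calling cuBLAS.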
if (A.data_ptr() == nullptr || B.data_ptr() == nullptr) {
if (D.data_ptr() != nullptr && !accumulate) D.zero_();
if (bias.data_ptr() != nullptr) bias.zero_();
if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_();
return;
}
auto te_A = makeTransformerEngineTensor(
A.data_ptr(), {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))}, A_type,
nullptr, nullptr, A_scale_inverse.data_ptr());
......@@ -77,3 +84,58 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream());
}
void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax,
std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype, void* amax_dptr,
void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
tensor_wrappers.emplace_back(
makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
for (size_t i = 0; i < A.size(); i++) {
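// Mirror te_gemm above: skip GEMMs whose inputs are empty and just zero their outputs.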
if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) {
if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_();
if (bias[i].data_ptr() != nullptr) bias[i].zero_();
if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_();
continue;
}
te_A.emplace_back(make_tensor(
A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i)));
te_B.emplace_back(make_tensor(
B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
te_D.emplace_back(make_tensor(
D[i].data_ptr(), {static_cast<size_t>(D[i].size(0)), static_cast<size_t>(D[i].size(1))},
D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
bias_type, nullptr, nullptr, nullptr));
const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
static_cast<size_t>(pre_gelu_out[i].size(1))};
te_pre_gelu_out.emplace_back(make_tensor(
pre_gelu_out[i].data_ptr(), gelu_shape,
GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
te_workspace.emplace_back(make_tensor(workspace[i % num_streams].data_ptr(), {workspaceSize},
DType::kByte, nullptr, nullptr, nullptr));
}
// For now, the multi-stream cuBLAS backend is the only implementation.
nvte_multi_stream_cublas_gemm(te_A, te_B, te_D, te_bias, te_pre_gelu_out, transa, transb, grad,
te_workspace, accumulate, use_split_accumulator, math_sm_count,
at::cuda::getCurrentCUDAStream());
}
......@@ -89,6 +89,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::call_guard<py::gil_scoped_release>());
m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard<py::gil_scoped_release>());
m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think
m.def("te_grouped_gemm", &te_grouped_gemm, "Grouped GEMM");
m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed QKV",
py::call_guard<py::gil_scoped_release>());
......@@ -271,6 +271,38 @@ at::Tensor te_gemm_ts(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_te
return D;
}
std::vector<at::Tensor> te_grouped_gemm_ts(
std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset, at::Tensor D_scale,
int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias, int64_t bias_type,
std::vector<at::Tensor> pre_gelu_out, int64_t grad, std::vector<at::Tensor> workspace,
int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator) {
// cast inputs to types accepted by te_gemm
transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
bool transa_arg = static_cast<bool>(transa);
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
// Set an external SM margin for all the GEMMs.
// This comes in handy when DP communication is overlapped with GEMMs
const int sm_count = transformer_engine::cuda::sm_count();
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
te_grouped_gemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse,
B_offset, B_type_arg, transb_arg, D, D_offset, D_scale, D_type_arg, D_amax, bias,
bias_type_arg, pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
accumulate_arg, use_split_accumulator_arg, num_math_sms);
return D;
}
at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, double eps, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor,
......@@ -336,6 +368,7 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("qgelu_ts", &qgelu_ts);
m.def("srelu_ts", &srelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);
......@@ -5,6 +5,7 @@
"""Module level PyTorch APIs"""
from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .grouped_linear import GroupedLinear
from .layernorm_mlp import LayerNormMLP
from .layernorm import LayerNorm
from .rmsnorm import RMSNorm
......@@ -41,9 +41,11 @@ __all__ = ["initialize_ub", "destroy_ub"]
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = []
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3
_NUM_MAX_CUBLAS_STREAMS = 4
layers_atomic_ring_exchange = []
......@@ -64,6 +66,17 @@ def get_workspace() -> torch.Tensor:
return _cublas_workspace
def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
"""Returns workspace for multi-stream cublas."""
global _multi_stream_cublas_workspace
if not _multi_stream_cublas_workspace:
for _ in range(_NUM_MAX_CUBLAS_STREAMS):
_multi_stream_cublas_workspace.append(
torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
)
return _multi_stream_cublas_workspace
def initialize_ub(
shape: list,
tp_group: dist_group_type,
This diff is collapsed.