"vscode:/vscode.git/clone" did not exist on "3c18d5644a7224e2151cc657d42f9fa2e61b3efe"
Unverified Commit a4e95e86 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Common/PyTorch] Grouped GEMM via multi-stream cuBLAS (#853)



* GroupedGEMM via multi-stream cublas

* fix the case where A/B is nullptr while D is not nullptr

* add fp8 grouped gemm

* register with TorchScript

* add the GroupedLinear layer

---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarJiang Shao <jiangs@nvidia.com>
Co-authored-by: default avatarQi Zhang <qizhang@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 85aeb903
...@@ -31,7 +31,8 @@ disable=too-many-locals, ...@@ -31,7 +31,8 @@ disable=too-many-locals,
global-variable-not-assigned, global-variable-not-assigned,
redefined-argument-from-local, redefined-argument-from-local,
line-too-long, line-too-long,
too-many-return-statements too-many-return-statements,
too-many-nested-blocks
[TYPECHECK] [TYPECHECK]
ignored-modules=torch ignored-modules=torch
......
...@@ -24,6 +24,7 @@ from transformer_engine.pytorch import ( ...@@ -24,6 +24,7 @@ from transformer_engine.pytorch import (
LayerNormLinear, LayerNormLinear,
LayerNormMLP, LayerNormMLP,
Linear, Linear,
GroupedLinear,
MultiheadAttention, MultiheadAttention,
RMSNorm, RMSNorm,
TransformerLayer, TransformerLayer,
...@@ -31,7 +32,9 @@ from transformer_engine.pytorch import ( ...@@ -31,7 +32,9 @@ from transformer_engine.pytorch import (
InferenceParams, InferenceParams,
) )
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
import transformer_engine_torch as tex
# Only run FP8 tests on H100. # Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
...@@ -1211,6 +1214,99 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): ...@@ -1211,6 +1214,99 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
    """Run forward + backward through ``block`` and collect all tensors to compare.

    ``block`` is either a ``GroupedLinear`` (invoked once with per-group split
    sizes) or a ``torch.nn.ModuleList`` of ``Linear`` layers (each split is fed
    to its own layer and the results concatenated).  Returns
    ``[output, input_grad, *param_grads]`` so two equivalent implementations
    can be checked for a bit-wise match.
    """
    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()
    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
    # Draw random split points, then turn them into per-group sizes that are
    # multiples of 16 (FP8 GEMM dimension requirement) and sum to seq_len.
    m = config.seq_len // 16
    dist = torch.sort(torch.randint(0, m, (num_gemms - 1,))).values.tolist()
    m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
    m_splits = m_splits * 16
    assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
    with fp8_autocast(enabled=fp8):
        if isinstance(block, GroupedLinear):
            # Scale splits by the batch size — presumably GroupedLinear flattens
            # the (seq, batch) leading dims internally; TODO confirm.
            m_splits = m_splits * bs
            out = block(inp_hidden_states, m_splits.tolist())
        else:
            # Reference path: run each split through its own Linear, concat along dim 0.
            out = torch.cat(
                [
                    block[i](inp)
                    for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                ]
            )
    loss = out.sum()
    loss.backward()
    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_grouped_linear_accuracy(dtype, num_gemms, bs, model, fp8, fp8_model_params):
    """GroupedLinear must match an equivalent list of Linear layers bit-wise."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")
    with fp8_model_init(enabled=fp8 and fp8_model_params):
        grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=True,
            params_dtype=dtype,
            device="cuda",
        ).eval()
        sequential_linear = torch.nn.ModuleList(
            [
                Linear(
                    config.hidden_size,
                    4 * config.hidden_size,
                    bias=True,
                    params_dtype=dtype,
                    device="cuda",
                ).eval()
                for _ in range(num_gemms)
            ]
        )
    # Share params: copy each GroupedLinear weight/bias into the matching
    # Linear so both implementations start from identical parameters.
    with torch.no_grad():
        for i in range(num_gemms):
            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
            sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
    outputs = _test_grouped_linear_accuracy(grouped_linear, num_gemms, bs, dtype, config, fp8)
    outputs_ref = _test_grouped_linear_accuracy(
        sequential_linear, num_gemms, bs, dtype, config, fp8
    )
    # Should be bit-wise match
    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
reset_rng_states() reset_rng_states()
...@@ -1563,3 +1659,157 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ...@@ -1563,3 +1659,157 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
# Check if the fully generated output matches the one generated incrementally # Check if the fully generated output matches the one generated incrementally
assert_allclose(full_output, incremental_output, atol[dtype]) assert_allclose(full_output, incremental_output, atol[dtype])
@pytest.mark.parametrize(
    "shape",
    [
        (1, 127, 128, 512),
        (8, 15, 128, 512),
        (8, 1027, 128, 512),
        (16, 10027, 128, 512),
    ],
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm(shape, dtype, layout, accumulate):
    """grouped_gemm must reproduce per-group single-GEMM results bit-wise.

    ``shape`` is ``(num_gemms, m, k, n)``; ``m`` is split randomly across the
    groups.  The three layouts mirror the GEMMs of a linear layer:
    TN = fprop, NN = dgrad, NT = wgrad.
    """
    torch.manual_seed(0)
    z, m, k, n = shape
    # Random split of m into z non-negative chunk sizes.
    dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
    m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
    assert m_splits.sum() == m and len(m_splits) == z
    m_splits = m_splits.tolist()
    if layout == "TN":
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
        out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # output
        grad = False
    elif layout == "NN":
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # grad_output
        out = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # dgrad
        grad = True
    else:  # layout == "NT"
        A = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
        B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # grad_output
        out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # wgrad
        grad = True
    # Baseline: one ordinary GEMM per group on cloned outputs (so `accumulate`
    # starts from the same values in both paths).
    out_ref = [o.clone() for o in out]
    for i in range(z):
        gemm(
            A[i],
            B[i],
            dtype,
            get_workspace(),
            grad=grad,
            accumulate=accumulate,
            layout=layout,
            out=out_ref[i],
        )
    grouped_gemm(
        A,
        B,
        out,
        dtype,
        get_multi_stream_cublas_workspace(),
        grad=grad,
        accumulate=accumulate,
        layout=layout,
    )
    # should be bit-wise match
    for o, o_ref in zip(out, out_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize(
    "shape",
    [
        (1, 128, 128, 512),
        (8, 1024, 128, 512),
        (16, 4096, 128, 512),
    ],
)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("accumulate", [False, True])
def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
    """fp8_grouped_gemm must match per-group fp8_gemm results bit-wise.

    ``shape`` is ``(num_gemms, m, k, n)`` with ``m`` split into equal chunks.
    A (weights) is always cast to E4M3; B (inputs) uses the parametrized FP8
    dtype.  Scale/scale_inv/amax are laid out contiguously as
    [A meta | B meta | out meta], matching fp8_grouped_gemm's assumption.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    z, m, k, n = shape
    m_splits = m // z  # equal chunk size; z divides m for all shapes above
    dtype = torch.bfloat16
    A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
    B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
    out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # output
    out_ref = [o.clone() for o in out]
    # fp8 should be robust enough to this fake scale
    scale = 1 + torch.rand(z * 3, dtype=torch.float32, device="cuda")
    scale_inv = 1 / scale
    amax = torch.zeros(1024, z * 3, dtype=torch.float32, device="cuda")
    A_fp8 = [
        torch.ops.tex_ts.cast_to_fp8_ts(
            A[i],
            scale,
            amax,
            scale_inv,
            i,  # fp8 meta tensor index
            tex.DType.kFloat8E4M3,
        )
        for i in range(z)
    ]
    B_fp8 = [
        torch.ops.tex_ts.cast_to_fp8_ts(
            B[i],
            scale,
            amax,
            scale_inv,
            z + i,  # fp8 meta tensor index
            fp8_dtype,
        )
        for i in range(z)
    ]
    fp8_grouped_gemm(
        A_fp8,
        scale_inv,
        0,  # A_offset
        tex.DType.kFloat8E4M3,
        B_fp8,
        scale_inv,
        z,  # B_offset
        fp8_dtype,
        out,
        dtype,
        get_multi_stream_cublas_workspace(),
        accumulate=accumulate,
    )
    # baseline: one fp8_gemm per group using the same meta-tensor offsets
    for i in range(z):
        fp8_gemm(
            A_fp8[i],
            scale_inv,
            i,
            tex.DType.kFloat8E4M3,
            B_fp8[i],
            scale_inv,
            z + i,
            fp8_dtype,
            dtype,
            get_workspace(),
            out=out_ref[i],
            accumulate=accumulate,
        )
    # should be bit-wise match
    for o, o_ref in zip(out, out_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include <cstdint> #include <cstdint>
#include <mutex>
#include "../common.h" #include "../common.h"
#include "../util/logging.h" #include "../util/logging.h"
...@@ -208,6 +209,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -208,6 +209,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue))); &epilogue, sizeof(epilogue)));
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 #if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
if (counter != nullptr) { if (counter != nullptr) {
if (m_split == 0) m_split = 1; if (m_split == 0) m_split = 1;
...@@ -275,6 +277,18 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -275,6 +277,18 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
} }
// Lazily-initialized side streams and events used to run grouped GEMMs
// concurrently; guarded by init_flag via std::call_once.
static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
static cudaEvent_t cublas_event[num_streams];

// Creates the non-blocking side compute streams (priority -1, i.e. higher
// than default) and one completion event per stream.
// Warning: only call once per device!
// NOTE(review): the streams are created on whichever CUDA device is current
// at the first call and reused for all subsequent calls — confirm behavior
// in multi-device processes.
static void init_streams_and_events() {
  for (int i = 0; i < num_streams; i++) {
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event[i]));
  }
}
} // namespace transformer_engine } // namespace transformer_engine
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
...@@ -363,3 +377,37 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor ...@@ -363,3 +377,37 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
} }
// Launches the grouped GEMMs round-robin across the internal side streams,
// fork/joining with the caller-supplied `stream` via events so ordering is
// preserved from the caller's point of view.
void nvte_multi_stream_cublas_gemm(std::vector<NVTETensor> A, std::vector<NVTETensor> B,
                                   std::vector<NVTETensor> D, std::vector<NVTETensor> bias,
                                   std::vector<NVTETensor> pre_gelu_out, bool transa, bool transb,
                                   bool grad, std::vector<NVTETensor> workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
  using namespace transformer_engine;
  // Inits streams and events (once, globally)
  std::call_once(init_flag, init_streams_and_events);
  int num_stream_used = std::min(num_streams, static_cast<int>(A.size()));
  // Fork: make every used side stream wait for work already queued on the
  // caller's stream, so no GEMM starts before its inputs are ready.
  // (cublas_event[0] is reused here before the join below re-records it.)
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
  }
  // Dispatch GEMM i (and its dedicated workspace) to stream i % num_streams.
  for (size_t i = 0; i < A.size(); i++) {
    nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                     workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
                     compute_streams[i % num_streams]);
  }
  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[s], compute_streams[s]));
  }
  // Join: the caller's stream waits for all side streams to finish.
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
  }
}
...@@ -78,8 +78,46 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor ...@@ -78,8 +78,46 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
bool use_split_accumulator, int math_sm_count, int m_split, bool use_split_accumulator, int math_sm_count, int m_split,
int n_split, bool gemm_producer, const NVTETensor counter, int n_split, bool gemm_producer, const NVTETensor counter,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Compute multiple pairs of matrix multiplication, potentially fused with other operations,
 *  on multiple streams.
 *
 *  Computes:
 *   - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
 *   - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
 *   - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
 *
 *  The i-th GEMM uses workspace[i % num_streams], so the workspace list is
 *  expected to hold one tensor per internal compute stream.
 *
 *  \param[in]     A                     The list of A matrices.
 *  \param[in]     B                     The list of B matrices.
 *  \param[in,out] D                     List of output matrices.
 *  \param[in]     bias                  List of bias tensors.
 *  \param[in,out] pre_gelu_out          List of output matrix before GELU activation.
 *  \param[in]     transa                Whether A matrix is transposed.
 *  \param[in]     transb                Whether B matrix is transposed.
 *  \param[in]     grad                  Whether this operation is part of the
 *                                       gradient computation.
 *  \param[out]    workspace             List of workspace tensors.
 *  \param[in]     accumulate            Whether to accumulate the result into the D matrix.
 *  \param[in]     use_split_accumulator Whether to use split accumulator in the FP8 GEMM.
 *  \param[in]     math_sm_count         Number of GPU SMs to use (default=0: use cuBLAS heuristics)
 *  \param[in]     stream                CUDA stream to wait on.
 */
void nvte_multi_stream_cublas_gemm(std::vector<NVTETensor> A, std::vector<NVTETensor> B,
                                   std::vector<NVTETensor> D, std::vector<NVTETensor> bias,
                                   std::vector<NVTETensor> pre_gelu_out, bool transa, bool transb,
                                   bool grad, std::vector<NVTETensor> workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
/*! \namespace transformer_engine
 */
namespace transformer_engine {

/*! Number of internal CUDA side streams used by the multi-stream grouped
 *  GEMM; also the expected length of its workspace list. */
constexpr int num_streams = 4;

}  // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GEMM_H_ #endif // TRANSFORMER_ENGINE_GEMM_H_
...@@ -37,8 +37,7 @@ from transformer_engine.pytorch.module import Linear ...@@ -37,8 +37,7 @@ from transformer_engine.pytorch.module import Linear
from transformer_engine.pytorch.module import LayerNormMLP from transformer_engine.pytorch.module import LayerNormMLP
from transformer_engine.pytorch.module import LayerNorm from transformer_engine.pytorch.module import LayerNorm
from transformer_engine.pytorch.module import RMSNorm from transformer_engine.pytorch.module import RMSNorm
from transformer_engine.pytorch.module import initialize_ub from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.pytorch.module import destroy_ub
from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch.attention import MultiheadAttention
......
...@@ -3,14 +3,14 @@ ...@@ -3,14 +3,14 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Python interface for GEMM extensions""" """Python interface for GEMM extensions"""
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union, List
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ..constants import TE_DType from ..constants import TE_DType
from ..utils import assert_dim_for_fp8_exec from ..utils import assert_dim_for_fp8_exec
__all__ = ["gemm", "fp8_gemm"] __all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"]
def fp8_gemm( def fp8_gemm(
...@@ -64,8 +64,6 @@ def fp8_gemm( ...@@ -64,8 +64,6 @@ def fp8_gemm(
bias_dtype = TE_DType[bias_dtype] bias_dtype = TE_DType[bias_dtype]
out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
if A.nelement() == 0 or B.nelement() == 0:
return out, gelu_input
args = ( args = (
A, A,
...@@ -216,8 +214,6 @@ def gemm( ...@@ -216,8 +214,6 @@ def gemm(
grad_bias = empty_tensor grad_bias = empty_tensor
bias = bias if use_bias else empty_tensor bias = bias if use_bias else empty_tensor
if A.nelement() == 0 or B.nelement() == 0:
return out, grad_bias, gelu_input
assert ( assert (
A.dtype == dtype and B.dtype == dtype A.dtype == dtype and B.dtype == dtype
...@@ -289,3 +285,158 @@ def gemm( ...@@ -289,3 +285,158 @@ def gemm(
_ = fn(*args) _ = fn(*args)
return out, grad_bias, gelu_input return out, grad_bias, gelu_input
def grouped_gemm(
    A: List[torch.Tensor],
    B: List[torch.Tensor],
    out: List[torch.Tensor],
    dtype: torch.dtype,
    workspaces: List[torch.Tensor],
    gelu: bool = False,
    gelu_input: Optional[List[torch.Tensor]] = None,
    grad: bool = False,
    accumulate: bool = False,
    layout: str = "TN",
    bias: Optional[List[torch.Tensor]] = None,
    use_bias: bool = False,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
    """Non FP8 Grouped GEMM.

    Dispatches one GEMM per (A[i], B[i]) pair through the multi-stream
    TorchScript op and returns ``(out, grad_bias, gelu_input)``.
    """
    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
    transa, transb = layout[0] == "T", layout[1] == "T"
    num_gemms = len(A)
    dummy = torch.Tensor()
    dummies = [torch.Tensor()] * num_gemms
    # Pre-GELU buffers are allocated here in the forward pass; in the backward
    # pass the caller passes the saved gelu_input in.
    if not gelu:
        gelu_input = dummies
    elif not grad:
        gelu_input = [torch.empty_like(o, dtype=dtype) for o in out]
    # In the backward pass with bias, the op writes per-group bias gradients.
    if grad and use_bias:
        grad_bias = [
            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
        ]
    else:
        grad_bias = dummies
    bias = bias if use_bias else dummies
    assert (
        A[0].dtype == dtype and B[0].dtype == dtype
    ), f"Expected dtype={dtype}, but found A.dtype={A[0].dtype} and B.dtype={B[0].dtype}"
    input_dtype = TE_DType[dtype]
    output_dtype = TE_DType[out[0].dtype]
    if not use_bias:
        bias_dtype = output_dtype
    elif grad:
        bias_dtype = TE_DType[grad_bias[0].dtype]
    else:
        bias_dtype = TE_DType[bias[0].dtype]
    torch.ops.tex_ts.te_grouped_gemm_ts(
        A,
        dummy,  # no FP8 scale_inv for A
        0,  # A_offset
        input_dtype,
        transa,
        B,
        dummy,  # no FP8 scale_inv for B
        0,  # B_offset
        input_dtype,
        transb,
        out,
        0,  # out_offset
        dummy,  # out_scale
        output_dtype,
        dummy,  # out_amax
        grad_bias if grad else bias,
        bias_dtype,
        gelu_input,  # gelu_input
        grad,
        workspaces,
        workspaces[0].shape[0],
        accumulate,
        False,  # use_split_accumulator
    )
    return out, grad_bias, gelu_input
def fp8_grouped_gemm(
    A: List[torch.Tensor],
    A_scale_inv: torch.Tensor,
    A_fp8_tensor_offset: int,
    A_dtype: tex.DType,
    B: List[torch.Tensor],
    B_scale_inv: torch.Tensor,
    B_fp8_tensor_offset: int,
    B_dtype: tex.DType,
    out: List[torch.Tensor],
    out_dtype: torch.dtype,
    workspaces: List[torch.Tensor],
    out_offset: Optional[int] = None,
    fp8_meta_tensor: tex.FP8TensorMeta = None,
    gelu: bool = False,
    accumulate: bool = False,
    bias: Optional[List[torch.Tensor]] = None,
    use_bias: bool = False,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
    """
    TN layout Grouped GEMM with fp8 inputs.
    This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor.
      scale:     [ ...A_scale...     | ...B_scale...     | ...out_scale...]
      scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...]
      amax:      [ ...A_amax...      | ...B_amax...      | ...out_amax...]

    The i-th GEMM reads its scale_inv at ``A_fp8_tensor_offset + i`` /
    ``B_fp8_tensor_offset + i``.  Returns ``(out, gelu_input)``.
    """
    num_gemms = len(A)
    empty_tensor = torch.Tensor()
    empty_tensors = [torch.Tensor()] * num_gemms
    # FP8 output requested: scale/amax bookkeeping must be supplied.
    if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
        assert fp8_meta_tensor is not None and out_offset is not None
    for a, b in zip(A, B):
        assert_dim_for_fp8_exec(a)
        assert_dim_for_fp8_exec(b)
    assert A[0].dtype == torch.uint8
    assert B[0].dtype == torch.uint8
    # Use bfloat16 as default bias_dtype
    bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
    if gelu:
        gelu_input = [torch.empty_like(o, dtype=bias_dtype) for o in out]
    else:
        gelu_input = empty_tensors
    bias_dtype = TE_DType[bias_dtype]
    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
    torch.ops.tex_ts.te_grouped_gemm_ts(
        A,
        A_scale_inv,
        A_fp8_tensor_offset,
        A_dtype,
        True,  # transa
        B,
        B_scale_inv,
        B_fp8_tensor_offset,
        B_dtype,
        False,  # transb
        out,
        0 if out_offset is None else out_offset,
        empty_tensor if out_offset is None else fp8_meta_tensor.scale,
        out_dtype,
        empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
        bias if use_bias else empty_tensors,
        bias_dtype,
        gelu_input,  # this is pre_gelu_out
        False,  # grad
        workspaces,
        workspaces[0].shape[0],
        accumulate,
        use_split_accumulator,
    )
    return out, gelu_input
...@@ -125,6 +125,16 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine ...@@ -125,6 +125,16 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
bool gemm_producer, at::Tensor counter); bool gemm_producer, at::Tensor counter);
// Grouped GEMM entry point: wraps each tensor list as NVTETensors and runs
// the GEMMs via the multi-stream cuBLAS backend (see gemm.cpp).
void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
                     transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
                     at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
                     bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
                     transformer_engine::DType D_type, at::Tensor D_amax,
                     std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
                     std::vector<at::Tensor> pre_gelu_out, bool grad,
                     std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
                     bool use_split_accumulator, int math_sm_count);
/*************************************************************************************************** /***************************************************************************************************
* Transpose * Transpose
**************************************************************************************************/ **************************************************************************************************/
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "common/util/cuda_runtime.h"
#include "extensions.h" #include "extensions.h"
void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
...@@ -14,6 +15,12 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType ...@@ -14,6 +15,12 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType
at::Tensor workspace, size_t workspaceSize, bool accumulate, at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) { bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine; using namespace transformer_engine;
if (A.data_ptr() == nullptr || B.data_ptr() == nullptr) {
if (D.data_ptr() != nullptr && !accumulate) D.zero_();
if (bias.data_ptr() != nullptr) bias.zero_();
if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_();
return;
}
auto te_A = makeTransformerEngineTensor( auto te_A = makeTransformerEngineTensor(
A.data_ptr(), {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))}, A_type, A.data_ptr(), {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))}, A_type,
nullptr, nullptr, A_scale_inverse.data_ptr()); nullptr, nullptr, A_scale_inverse.data_ptr());
...@@ -77,3 +84,58 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine ...@@ -77,3 +84,58 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
accumulate, use_split_accumulator, math_sm_count, m_split, n_split, accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream()); gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream());
} }
// Converts the per-group at::Tensor lists into NVTETensor lists and launches
// them through the multi-stream grouped-GEMM backend.  A_offset/B_offset/
// D_offset index into the shared FP8 scale/scale_inv/amax meta tensors.
void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
                     transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
                     at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
                     bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
                     transformer_engine::DType D_type, at::Tensor D_amax,
                     std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
                     std::vector<at::Tensor> pre_gelu_out, bool grad,
                     std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
                     bool use_split_accumulator, int math_sm_count) {
  using namespace transformer_engine;
  std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
  // tensor_wrappers owns the TensorWrapper objects so the raw NVTETensor
  // handles stay valid for the duration of the call.
  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
  auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
                                        transformer_engine::DType dtype, void* amax_dptr,
                                        void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
    tensor_wrappers.emplace_back(
        makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
    return tensor_wrappers.back().data();
  };
  for (size_t i = 0; i < A.size(); i++) {
    // Empty group: A*B would be all zeros, so emulate the GEMM by zeroing the
    // outputs (unless accumulating into D) and skip the launch entirely.
    if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) {
      if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_();
      if (bias[i].data_ptr() != nullptr) bias[i].zero_();
      if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_();
      continue;
    }
    te_A.emplace_back(make_tensor(
        A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
        A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i)));
    te_B.emplace_back(make_tensor(
        B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
        B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
    te_D.emplace_back(make_tensor(
        D[i].data_ptr(), {static_cast<size_t>(D[i].size(0)), static_cast<size_t>(D[i].size(1))},
        D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
    te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
                                     bias_type, nullptr, nullptr, nullptr));
    // A null pre_gelu_out placeholder is passed through as a 1-D shape.
    const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
                                ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
                                : std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
                                                      static_cast<size_t>(pre_gelu_out[i].size(1))};
    te_pre_gelu_out.emplace_back(make_tensor(
        pre_gelu_out[i].data_ptr(), gelu_shape,
        GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
    // Workspaces are shared round-robin to match the backend's stream mapping.
    te_workspace.emplace_back(make_tensor(workspace[i % num_streams].data_ptr(), {workspaceSize},
                                          DType::kByte, nullptr, nullptr, nullptr));
  }
  // For now, we only have multi-stream cublas backend.
  nvte_multi_stream_cublas_gemm(te_A, te_B, te_D, te_bias, te_pre_gelu_out, transa, transb, grad,
                                te_workspace, accumulate, use_split_accumulator, math_sm_count,
                                at::cuda::getCurrentCUDAStream());
}
...@@ -89,6 +89,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -89,6 +89,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard<py::gil_scoped_release>()); m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard<py::gil_scoped_release>());
m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think
m.def("te_grouped_gemm", &te_grouped_gemm, "Grouped GEMM");
m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed QKV", "Fused Attention FP8/BF16/FP16 FWD with packed QKV",
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
......
...@@ -271,6 +271,38 @@ at::Tensor te_gemm_ts(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_te ...@@ -271,6 +271,38 @@ at::Tensor te_gemm_ts(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_te
return D; return D;
} }
// TorchScript-friendly wrapper around te_grouped_gemm: all enum/bool/size
// parameters arrive as int64_t (the only integer type TorchScript passes)
// and are converted before dispatch.  Returns D so the op has a value.
std::vector<at::Tensor> te_grouped_gemm_ts(
    std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
    int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
    int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset, at::Tensor D_scale,
    int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias, int64_t bias_type,
    std::vector<at::Tensor> pre_gelu_out, int64_t grad, std::vector<at::Tensor> workspace,
    int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator) {
  // cast inputs to types accepted by te_gemm
  transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
  bool transa_arg = static_cast<bool>(transa);
  transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
  bool transb_arg = static_cast<bool>(transb);
  transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
  transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
  bool grad_arg = static_cast<bool>(grad);
  size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
  bool accumulate_arg = static_cast<bool>(accumulate);
  bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);

  // Set an external SM Margin to all the GEMMs.
  // This comes in handy when DP is overlapped with GEMMs
  const int sm_count = transformer_engine::cuda::sm_count();
  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

  te_grouped_gemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse,
                  B_offset, B_type_arg, transb_arg, D, D_offset, D_scale, D_type_arg, D_amax, bias,
                  bias_type_arg, pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
                  accumulate_arg, use_split_accumulator_arg, num_math_sms);
  return D;
}
at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, double eps, at::Tensor scale, const at::Tensor &bias, double eps, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor,
...@@ -336,6 +368,7 @@ TORCH_LIBRARY(tex_ts, m) { ...@@ -336,6 +368,7 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("qgelu_ts", &qgelu_ts); m.def("qgelu_ts", &qgelu_ts);
m.def("srelu_ts", &srelu_ts); m.def("srelu_ts", &srelu_ts);
m.def("te_gemm_ts", &te_gemm_ts); m.def("te_gemm_ts", &te_gemm_ts);
m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts); m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
"""Module level PyTorch APIs""" """Module level PyTorch APIs"""
from .layernorm_linear import LayerNormLinear from .layernorm_linear import LayerNormLinear
from .linear import Linear from .linear import Linear
from .grouped_linear import GroupedLinear
from .layernorm_mlp import LayerNormMLP from .layernorm_mlp import LayerNormMLP
from .layernorm import LayerNorm from .layernorm import LayerNorm
from .rmsnorm import RMSNorm from .rmsnorm import RMSNorm
......
...@@ -41,9 +41,11 @@ __all__ = ["initialize_ub", "destroy_ub"] ...@@ -41,9 +41,11 @@ __all__ = ["initialize_ub", "destroy_ub"]
_2X_ACC_FPROP = False _2X_ACC_FPROP = False
_2X_ACC_DGRAD = True _2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True _2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = []
_cublas_workspace = None _cublas_workspace = None
_ub_communicators = None _ub_communicators = None
_NUM_MAX_UB_STREAMS = 3 _NUM_MAX_UB_STREAMS = 3
_NUM_MAX_CUBLAS_STREAMS = 4
layers_atomic_ring_exchange = [] layers_atomic_ring_exchange = []
...@@ -64,6 +66,17 @@ def get_workspace() -> torch.Tensor: ...@@ -64,6 +66,17 @@ def get_workspace() -> torch.Tensor:
return _cublas_workspace return _cublas_workspace
def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
    """Returns workspace for multi-stream cublas.

    One uint8 CUDA buffer per cuBLAS side stream, allocated lazily on first
    use and cached at module level for all later calls.
    """
    global _multi_stream_cublas_workspace
    if not _multi_stream_cublas_workspace:
        nbytes = get_cublas_workspace_size_bytes()
        _multi_stream_cublas_workspace.extend(
            torch.empty(nbytes, dtype=torch.uint8, device="cuda")
            for _ in range(_NUM_MAX_CUBLAS_STREAMS)
        )
    return _multi_stream_cublas_workspace
def initialize_ub( def initialize_ub(
shape: list, shape: list,
tp_group: dist_group_type, tp_group: dist_group_type,
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment