"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd"
Commit 9d0f1c9b authored by yuguo's avatar yuguo
Browse files

[DCU] add batchgemm test

parent e8f92b93
...@@ -26,6 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test ...@@ -26,6 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Installation script.""" """Installation script."""
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v # NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
# VTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel # NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
import os import os
import sys import sys
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from collections import OrderedDict
import math
import os
from typing import Dict, List, Tuple, Optional
import pytest
import copy
import random
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
fp8_autocast,
fp8_model_init,
)
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
attention_mask_func,
is_bf16_compatible,
)
from transformer_engine.pytorch import (
DotProductAttention,
LayerNormLinear,
LayerNormMLP,
Linear,
GroupedLinear,
BatchedLinear,
MultiheadAttention,
RMSNorm,
TransformerLayer,
LayerNorm,
Fp8Padding,
Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm, batchgemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# True when the current device reports compute capability >= sm_80.
sm_80plus = get_device_compute_capability() >= (8, 0)
# Fixed seed so every test run draws the same random inputs.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run; reset_rng_states() restores these
# so each test starts from an identical RNG state.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
# torch 2.7 renamed dynamo's cache_size_limit to recompile_limit; set
# whichever name this torch version understands.
if torch_version() >= (2, 7, 0):
    torch._dynamo.config.recompile_limit = 16
else:
    torch._dynamo.config.cache_size_limit = 16
class ModelConfig:
    """Bundle of model hyperparameters used to parametrize the tests."""

    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
        # Record every hyperparameter verbatim on the instance.
        self.hidden_size, self.eps = hidden_size, eps
        self.num_attention_heads, self.embed = num_attention_heads, embed
        self.num_layers, self.seq_len = num_layers, seq_len
# Named model configurations used by the accuracy tests below.
model_configs = {
    "small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}
# Smaller seq_len variant used by inference-oriented tests.
model_configs_inference = {
    # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
# Dtypes swept by @pytest.mark.parametrize; bf16 is appended only when the
# device supports it.
param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)
batch_sizes = [1, 2]
all_boolean = [True, False]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
mask_types = ["causal", "no_mask"]
# FP8 recipes swept by the tests; individual tests skip recipes they do not
# support (see test_batched_linear_accuracy).
fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
    recipe.Float8CurrentScaling(),
]
def get_causal_attn_mask(sq: int, device: str = "cuda") -> torch.Tensor:
    """Build a causal (upper-triangular) attention mask.

    Parameters
    ----------
    sq: int
        Sequence length; the mask is `sq x sq`.
    device: str, default = "cuda"
        Device to allocate the mask on. Defaults to "cuda" for backward
        compatibility with existing callers.

    Returns
    -------
    torch.Tensor
        Boolean tensor where True marks positions strictly above the main
        diagonal (i.e. future positions that must be masked out).
    """
    return torch.triu(torch.ones(sq, sq, device=device), diagonal=1).bool()
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype.

    Based on tolerances for torch.testing.assert_close.

    Parameters
    ----------
    dtype: torch.dtype
        One of torch.float32, torch.float16, torch.bfloat16.

    Returns
    -------
    dict
        `rtol`/`atol` keyword arguments suitable for torch.allclose or
        torch.testing.assert_close.

    Raises
    ------
    ValueError
        If `dtype` is not one of the supported floating-point types.
    """
    # Table-driven lookup; same values as the original if/elif chain.
    tols = {
        torch.float32: dict(rtol=1.3e-6, atol=1e-5),
        torch.float16: dict(rtol=1e-3, atol=1e-5),
        torch.bfloat16: dict(rtol=1.6e-2, atol=1e-5),
    }
    try:
        return tols[dtype]
    except KeyError:
        # Fixed typo in the original message ("Unsuppored").
        raise ValueError(f"Unsupported dtype ({dtype})") from None
def assert_allclose(
    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
) -> None:
    """Check that two lists of tensors match elementwise.

    Parameters
    ----------
    l1, l2: List[torch.Tensor]
        Tensor lists to compare; must have the same length.
    atol: float
        Absolute tolerance.
    rtol: float, optional
        Relative tolerance. When omitted, torch.allclose's default
        (rtol=1e-5) applies.

    Raises
    ------
    AssertionError
        If the lists differ in length or any tensor pair is not close,
        with the location and magnitude of the worst mismatch when it can
        be identified.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        tols = dict(atol=atol)
        if rtol is not None:
            tols["rtol"] = rtol
        if torch.allclose(t1, t2, **tols):
            continue
        # Recompute the per-element tolerance to locate the worst offender.
        # When rtol was omitted, mirror torch.allclose's default (1e-5);
        # the original code crashed here with a TypeError (None * tensor).
        eff_rtol = 1e-5 if rtol is None else rtol
        diff = torch.abs(t1 - t2)
        tol = atol + eff_rtol * torch.abs(t2)
        exceed_mask = diff > tol
        if exceed_mask.any():
            indices = torch.nonzero(exceed_mask, as_tuple=True)
            max_diff = diff[exceed_mask].max()
            max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
            max_location = [idx[max_idx].item() for idx in indices]
            raise AssertionError(
                f"Outputs not close enough in tensor at idx={i}. "
                f"Maximum difference at location {max_location} "
                f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
                f"(diff {max_diff.item()})."
            )
        # allclose failed but no element exceeds the recomputed tolerance
        # (e.g. NaN/inf mismatches). The original silently passed here;
        # still report the failure.
        raise AssertionError(f"Outputs not close enough in tensor at idx={i}.")
def reset_rng_states() -> None:
    """Restore the CPU and CUDA RNG states captured at module import."""
    for restore, saved_state in (
        (torch.set_rng_state, _cpu_rng_state),
        (torch.cuda.set_rng_state, _cuda_rng_state),
    ):
        restore(saved_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture: reset the global FP8 state after every test."""
    # Run the test first, then clean up shared FP8 state so one test's
    # autocast/recipe configuration cannot leak into the next.
    yield
    FP8GlobalStateManager.reset()
def _test_batched_linear_accuracy(
    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
):
    """Run one forward/backward pass through `block` and return its outputs.

    `block` is either a BatchedLinear or a ModuleList of per-GEMM Linear
    modules; both paths consume the same (seq_len, bs, hidden_size) random
    input so their results can be compared by the caller.

    Returns [output, input_gradient].
    """
    # Identical RNG state for every invocation so the batched and
    # sequential runs see the same random input.
    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()
    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
    # Split the sequence dimension evenly across the GEMMs.
    assert config.seq_len % num_gemms == 0
    m_splits = torch.tensor([config.seq_len // num_gemms for i in range(num_gemms)])
    assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, BatchedLinear):
            # NOTE(review): splits are scaled by the batch size here —
            # presumably BatchedLinear flattens (seq, batch) into rows;
            # confirm against the BatchedLinear implementation.
            m_splits = m_splits * bs
            out = block(inp_hidden_states, m_splits.tolist())
        else:
            # Reference path: run each split through its own Linear and
            # concatenate along the sequence dimension.
            out = torch.cat(
                [
                    block[i](inp)
                    for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                ]
            )
        loss = out.sum()
        loss.backward()
    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [4, 8])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_batched_linear_accuracy(
    dtype,
    num_gemms,
    bs,
    model,
    fp8,
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    parallel_mode=None,
):
    """Compare BatchedLinear against an equivalent ModuleList of Linears.

    Builds one BatchedLinear and num_gemms standalone Linear modules that
    share the same weights, runs both through
    _test_batched_linear_accuracy, and checks outputs and input gradients
    agree within tolerance.
    """
    # Number of GEMMs batched into one weight tensor; each BatchedLinear
    # weight{i} is sliced into `batch_num` sequential Linear weights below.
    # NOTE(review): assumes num_gemms is divisible by batch_num — holds for
    # the parametrized values (4, 8) with the default of 2; confirm for
    # other NVTE_MOE_BATCHCOUNT settings.
    batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
        pytest.skip("MXFP8 unsupported for batched linear.")
    if fp8 and recipe.float8_current_scaling():
        pytest.skip("Float8 Current Scaling unsupported for batched linear.")
    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")
    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        batched_linear = BatchedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
        ).eval()
        sequential_linear = torch.nn.ModuleList(
            [
                Linear(
                    config.hidden_size,
                    4 * config.hidden_size,
                    bias=False,
                    params_dtype=dtype,
                    parallel_mode=parallel_mode,
                    device="cuda",
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                ).eval()
                for _ in range(num_gemms)
            ]
        )
    # Share params: copy row-slices of each batched weight into the
    # corresponding sequential Linear modules so both paths compute with
    # identical parameters.
    with torch.no_grad():
        for i in range(num_gemms // batch_num):
            weight = getattr(batched_linear, f"weight{i}").clone()
            # bias = getattr(batched_linear, f"bias{i}").clone()
            if fuse_wgrad_accumulation:
                # Seed main_grad with random values so fused wgrad
                # accumulation is exercised with a nonzero accumulator.
                weight_i = getattr(batched_linear, f"weight{i}")
                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
            for j in range(batch_num):
                # Slice j-th of batch_num equal row-chunks of this weight.
                sequential_linear[i * batch_num + j].weight = Parameter(weight[weight.shape[0] // batch_num * j : weight.shape[0] // batch_num * (j + 1)].clone())
                # sequential_linear[i * batch_num + j].bias = Parameter(bias[bias.shape[0] // batch_num * j : bias.shape[0] // batch_num * (j + 1)].clone())
                if fuse_wgrad_accumulation:
                    sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[weight_i.main_grad.shape[0] // batch_num * j : weight_i.main_grad.shape[0] // batch_num * (j + 1)].clone()
    outputs_ref = _test_batched_linear_accuracy(
        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
    )
    outputs = _test_batched_linear_accuracy(
        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
    )
    # Outputs and input gradients should match closely. NOTE(review): the
    # original comment claimed a bit-wise match, but the 6e-3 tolerances
    # below only enforce approximate agreement.
    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)
if __name__ == "__main__":
    # Direct-run convenience: a single configuration, bypassing pytest's
    # parametrization and skip logic (fp8=False, so the Float8CurrentScaling
    # recipe argument is unused).
    test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True)
...@@ -778,6 +778,9 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, ...@@ -778,6 +778,9 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias); const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out); Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
Tensor *wspace = reinterpret_cast<Tensor *>(workspace); Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
NVTE_ERROR("MOE batchgemm not surpport bias or gelu.");
}
int m, n, k; int m, n, k;
if (!transa && transb) { if (!transa && transb) {
......
...@@ -18,7 +18,7 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase ...@@ -18,7 +18,7 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
__all__ = [ __all__ = [
"general_gemm", "general_gemm",
"general_grouped_gemm", "general_grouped_gemm",
"general_batched_gemm", "batchgemm",
] ]
...@@ -226,84 +226,78 @@ def general_grouped_gemm( ...@@ -226,84 +226,78 @@ def general_grouped_gemm(
return out, bias, gelu_input return out, bias, gelu_input
def general_batched_gemm( def batchgemm(
A: List[torch.Tensor], A: List[torch.Tensor],
B: List[torch.Tensor], B: List[torch.Tensor],
out: List[torch.Tensor], out: List[torch.Tensor],
out_dtype: torch.dtype, dtype: torch.dtype,
workspaces: List[torch.Tensor], workspaces: List[torch.Tensor],
layout: str = "TN",
m_splits: Optional[List[int]] = None,
gelu: bool = False, gelu: bool = False,
grad=False, gelu_input: Optional[List[torch.Tensor]] = None,
grad: bool = False,
accumulate: bool = False, accumulate: bool = False,
layout: str = "TN",
bias: Optional[List[torch.Tensor]] = None, bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False, use_bias: bool = False,
use_split_accumulator: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]:
D_dtype: Optional[tex.DType] = None, """Non FP8 batch GEMM."""
single_output=False,
) -> Tuple[List[torch.Tensor], ...]:
"""
TN layout Grouped GEMM with fp8 inputs.
"""
num_gemms = len(A)
assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
transa = layout[0] == "T" transa = layout[0] == "T"
transb = layout[1] == "T" transb = layout[1] == "T"
num_gemms = len(A)
empty_tensor = torch.Tensor()
empty_tensors = [torch.Tensor()] * num_gemms
# assert [a.is_contiguous() for a in A] if gelu and not grad:
# assert [b.is_contiguous() for b in B] gelu_input = [
torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format) for o in out
if isinstance(A[0], Float8TensorBase): ]
for a, b in zip(A, B): elif not gelu:
assert_dim_for_fp8_exec(a._data)
assert_dim_for_fp8_exec(b._data)
empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms
# Use bfloat16 as default bias_dtype
gelu_input = empty_tensors gelu_input = empty_tensors
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
sm_count = get_sm_count()
if grad and use_bias: if grad and use_bias:
grad_bias = [ grad_bias = [
torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms) torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
] ]
else: else:
grad_bias = empty_tensors grad_bias = empty_tensors
bias = bias if use_bias else empty_tensors bias = bias if use_bias else empty_tensors
assert (
A[0].dtype == dtype and B[0].dtype == dtype
), f"Expected dtype={dtype}, but found A.dtype={A[0].dtype} and B.dtype={B[0].dtype}"
input_dtype = TE_DType[dtype]
output_dtype = TE_DType[out[0].dtype]
if use_bias: if use_bias:
bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype] bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
else: else:
bias_dtype = TE_DType[torch.bfloat16] bias_dtype = output_dtype
tex.te_batchgemm_ts(
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
] # this should differ with respect to single output
bias = tex.te_general_batched_gemm(
A, A,
empty_tensor,
0, # A_offset
input_dtype,
transa, transa,
B, B,
empty_tensor,
0, # B_offset
input_dtype,
transb, transb,
out, out,
out_dtype, 0, # out_offset
m_splits, empty_tensor, # out_scale
output_dtype,
empty_tensor, # out_amax
grad_bias if grad else bias, grad_bias if grad else bias,
bias_dtype, bias_dtype,
single_output, gelu_input, # gelu_input
gelu_input, # this is pre_gelu_out grad,
grad, # grad
workspaces, workspaces,
workspaces[0].shape[0], workspaces[0].shape[0],
accumulate, accumulate,
use_split_accumulator, False, # use_split_accumulator
sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
) )
return out, bias, gelu_input return out, grad_bias, gelu_input
...@@ -102,13 +102,23 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm( ...@@ -102,13 +102,23 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
bool use_split_accumulator, int math_sm_count); bool use_split_accumulator, int math_sm_count);
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
std::optional<std::vector<at::Tensor>> te_general_batched_gemm( void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb, transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type, at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
std::vector<int64_t> m_splits, std::vector<at::Tensor> bias, bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out, transformer_engine::DType D_type, at::Tensor D_amax,
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate, std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count); bool use_split_accumulator, int math_sm_count);
std::vector<at::Tensor> te_batchgemm_ts(
std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset, at::Tensor D_scale,
int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias, int64_t bias_type,
std::vector<at::Tensor> pre_gelu_out, int64_t grad, std::vector<at::Tensor> workspace,
int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator);
#endif #endif
/*************************************************************************************************** /***************************************************************************************************
......
...@@ -424,123 +424,104 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm( ...@@ -424,123 +424,104 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
} }
#ifdef USE_ROCM #ifdef USE_ROCM
std::optional<std::vector<at::Tensor>> te_general_batched_gemm( void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb, transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type, at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
std::vector<int64_t> m_splits, std::vector<at::Tensor> bias, bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out, transformer_engine::DType D_type, at::Tensor D_amax,
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate, std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) { bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine; using namespace transformer_engine;
using namespace transformer_engine::pytorch; using namespace transformer_engine::pytorch;
std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector, std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
te_pre_gelu_out_vector, te_workspace_vector; std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
std::vector<TensorWrapper> wrappers; auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
std::vector<at::Tensor> D_vectors; transformer_engine::DType dtype, void* amax_dptr,
void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
auto none = py::none(); tensor_wrappers.emplace_back(
makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
std::vector<size_t> single_output_begins; return tensor_wrappers.back().data();
std::vector<size_t> single_output_ends; };
int slicing_dim; for (size_t i = 0; i < A.size(); i++) {
if (single_output && D == std::nullopt) { if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) {
NVTE_ERROR("not implemented, D should be allocated for single output case."); if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_();
if (bias[i].data_ptr() != nullptr) bias[i].zero_();
if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_();
continue;
} }
te_A.emplace_back(make_tensor(
A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i)));
te_B.emplace_back(make_tensor(
B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
te_D.emplace_back(make_tensor(
D[i].data_ptr(), {static_cast<size_t>(D[i].size(0)), static_cast<size_t>(D[i].size(1))},
D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
bias_type, nullptr, nullptr, nullptr));
void* output_data_ptr; const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
if (single_output) { ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
output_data_ptr = (*D)[0].data_ptr(); : std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
static_cast<size_t>(pre_gelu_out[i].size(1))};
te_pre_gelu_out.emplace_back(make_tensor(
pre_gelu_out[i].data_ptr(), gelu_shape,
GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
}
for (size_t i = 0; i < workspace.size(); i++) {
te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
nullptr, nullptr, nullptr));
} }
for (size_t i = 0; i < A.size(); i++) { nvte_multi_stream_cublas_batchgemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
auto te_A = makeTransformerEngineTensor(A[i], none); te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
auto te_B = makeTransformerEngineTensor(B[i], none); te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}
// if there is single output transformer_engine::DType reverse_map_dtype(int64_t dtype) {
at::Tensor out_tensor; if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) {
auto size_t_shape = return static_cast<transformer_engine::DType>(dtype);
pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb);
bool D_numel_is_zero = false;
std::vector<int64_t> D_shape;
for (size_t t : size_t_shape) {
D_shape.push_back(t);
if (t == 0) {
D_numel_is_zero = true;
}
}
auto dtype = GetATenDType(D_type);
auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
if (single_output) {
if (output_data_ptr == nullptr) {
out_tensor = at::empty(D_shape, opts);
} else {
// We need to check !D_numel_is_zero because if the final input portion has zero elements,
// output_data_ptr would point beyond the allocated memory of D. This would cause
// at::from_blob to fail as it would reference memory not allocated by CUDA.
if (!D_numel_is_zero) {
out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
}
}
char* char_ptr = reinterpret_cast<char*>(output_data_ptr);
char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size();
output_data_ptr = reinterpret_cast<void*>(char_ptr);
D_vectors.emplace_back(out_tensor);
} else {
if (D == std::nullopt) {
auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
out_tensor = at::empty(D_shape, opts);
D_vectors.emplace_back(out_tensor);
} else { } else {
out_tensor = (*D)[i]; NVTE_ERROR("Type not supported.");
}
}
if (te_A.numel() == 0 || te_B.numel() == 0) {
if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_();
if (bias[i].numel() != 0 && grad) {
bias[i].zero_();
} }
if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_(); }
continue;
}
auto te_D = makeTransformerEngineTensor(out_tensor);
auto te_bias = makeTransformerEngineTensor(bias[i]);
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr std::vector<at::Tensor> te_batchgemm_ts(
? std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0))} std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
: std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0)), int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
static_cast<size_t>(te_pre_gelu_out.size(1))}; int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset, at::Tensor D_scale,
int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias, int64_t bias_type,
std::vector<at::Tensor> pre_gelu_out, int64_t grad, std::vector<at::Tensor> workspace,
int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
// cast inputs to types accepted by te_gemm
transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
bool transa_arg = static_cast<bool>(transa);
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
DType gelu_type = bias_type; // Set an external SM Margin to all the GEMMs.
te_pre_gelu_out = // This comes in handy when DP is overlapped with GEMMs
makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type);
te_A_vector.emplace_back(te_A.data()); const int sm_count = transformer_engine::cuda::sm_count();
te_B_vector.emplace_back(te_B.data()); int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
te_D_vector.emplace_back(te_D.data());
te_bias_vector.emplace_back(te_bias.data());
te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
wrappers.emplace_back(std::move(te_A)); te_batchgemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse,
wrappers.emplace_back(std::move(te_B)); B_offset, B_type_arg, transb_arg, D, D_offset, D_scale, D_type_arg, D_amax, bias,
wrappers.emplace_back(std::move(te_D)); bias_type_arg, pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
wrappers.emplace_back(std::move(te_bias)); accumulate_arg, use_split_accumulator_arg, num_math_sms);
wrappers.emplace_back(std::move(te_pre_gelu_out)); return D;
}
for (size_t i = 0; i < workspace.size(); i++) {
auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte);
te_workspace_vector.emplace_back(wsp.data());
wrappers.emplace_back(std::move(wsp));
}
// For now, we only have multi-stream cublas backend.
nvte_multi_stream_cublas_batchgemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
te_bias_vector.data(), te_pre_gelu_out_vector.data(),
te_A_vector.size(), transa, transb, grad,
te_workspace_vector.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
return bias;
} }
#endif #endif
...@@ -175,7 +175,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -175,7 +175,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM"); m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
#ifdef USE_ROCM #ifdef USE_ROCM
m.def("te_general_batched_gemm", &te_general_batched_gemm, "Batched GEMM"); /// rocblas m.def("te_batchgemm_ts", &te_batchgemm_ts, "Batched GEMM"); /// rocblas
#endif #endif
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>()); py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment