Commit 9d0f1c9b authored by yuguo

[DCU] add batchgemm test

parent e8f92b93
@@ -26,6 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
+PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
......
@@ -4,7 +4,7 @@
"""Installation script."""
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
-# VTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
+# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
import os
import sys
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from collections import OrderedDict
import math
import os
from typing import Dict, List, Tuple, Optional
import pytest
import copy
import random
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
fp8_autocast,
fp8_model_init,
)
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
attention_mask_func,
is_bf16_compatible,
)
from transformer_engine.pytorch import (
DotProductAttention,
LayerNormLinear,
LayerNormMLP,
Linear,
GroupedLinear,
BatchedLinear,
MultiheadAttention,
RMSNorm,
TransformerLayer,
LayerNorm,
Fp8Padding,
Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm, batchgemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
sm_80plus = get_device_compute_capability() >= (8, 0)
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
if torch_version() >= (2, 7, 0):
torch._dynamo.config.recompile_limit = 16
else:
torch._dynamo.config.cache_size_limit = 16
class ModelConfig:
def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
self.hidden_size = hidden_size
self.eps = eps
self.num_attention_heads = num_attention_heads
self.embed = embed
self.num_layers = num_layers
self.seq_len = seq_len
model_configs = {
"small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}
model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
batch_sizes = [1, 2]
all_boolean = [True, False]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
mask_types = ["causal", "no_mask"]
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
]
def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
raise ValueError(f"Unsuppored dtype ({dtype})")
def assert_allclose(
    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: Optional[float] = None
) -> None:
    """Ensures all tensors in two lists are close within the given tolerances."""
assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)):
tols = dict(atol=atol)
if rtol is not None:
tols["rtol"] = rtol
result = torch.allclose(t1, t2, **tols)
if not result:
diff = torch.abs(t1 - t2)
            tol = atol + ((rtol if rtol is not None else 1e-5) * torch.abs(t2))  # 1e-5 matches torch.allclose's default rtol
exceed_mask = diff > tol
if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
max_diff = diff[exceed_mask].max()
max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
max_location = [idx[max_idx].item() for idx in indices]
msg = (
f"Outputs not close enough in tensor at idx={i}. "
f"Maximum difference at location {max_location} "
f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
f"(diff {max_diff.item()})."
)
raise AssertionError(msg)
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
def _test_batched_linear_accuracy(
block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
):
reset_rng_states()
if fp8:
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
assert config.seq_len % num_gemms == 0
m_splits = torch.tensor([config.seq_len // num_gemms for i in range(num_gemms)])
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if isinstance(block, BatchedLinear):
m_splits = m_splits * bs
out = block(inp_hidden_states, m_splits.tolist())
else:
out = torch.cat(
[
block[i](inp)
for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
]
)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [4, 8])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_batched_linear_accuracy(
dtype,
num_gemms,
bs,
model,
fp8,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
parallel_mode=None,
):
batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches
pytest.skip("MXFP8 unsupported for batched linear.")
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for batched linear.")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
batched_linear = BatchedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
sequential_linear = torch.nn.ModuleList(
[
Linear(
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
for _ in range(num_gemms)
]
)
# Share params
with torch.no_grad():
for i in range(num_gemms // batch_num):
weight = getattr(batched_linear, f"weight{i}").clone()
# bias = getattr(batched_linear, f"bias{i}").clone()
if fuse_wgrad_accumulation:
weight_i = getattr(batched_linear, f"weight{i}")
weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
for j in range(batch_num):
sequential_linear[i * batch_num + j].weight = Parameter(weight[weight.shape[0] // batch_num * j : weight.shape[0] // batch_num * (j + 1)].clone())
# sequential_linear[i * batch_num + j].bias = Parameter(bias[bias.shape[0] // batch_num * j : bias.shape[0] // batch_num * (j + 1)].clone())
if fuse_wgrad_accumulation:
sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[weight_i.main_grad.shape[0] // batch_num * j : weight_i.main_grad.shape[0] // batch_num * (j + 1)].clone()
outputs_ref = _test_batched_linear_accuracy(
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
)
outputs = _test_batched_linear_accuracy(
batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
)
    # Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)
if __name__ == "__main__":
test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True)
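For reference, a minimal sketch of the usage pattern the accuracy test above exercises, assuming this branch's BatchedLinear API and the NVTE_MOE_BATCHCOUNT environment variable it reads (shapes follow the "126m" config; this is an illustration, not part of the commit):

import os
import torch
from transformer_engine.pytorch import BatchedLinear

os.environ.setdefault("NVTE_MOE_BATCHCOUNT", "2")  # batch count read inside BatchedLinear
num_gemms, bs, seq_len, hidden = 4, 2, 2048, 768

layer = BatchedLinear(num_gemms, hidden, 4 * hidden, bias=False, params_dtype=torch.float32, device="cuda")
inp = torch.randn(seq_len, bs, hidden, device="cuda", requires_grad=True)
# One row count per GEMM; the test scales the per-GEMM splits by the batch size
# because the (seq_len, bs, hidden) input is flattened to (seq_len * bs, hidden).
m_splits = [seq_len // num_gemms * bs] * num_gemms
out = layer(inp, m_splits)
out.sum().backward()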
@@ -778,6 +778,9 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
+  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
+    NVTE_ERROR("MOE batchgemm does not support bias or gelu.");
+  }
  int m, n, k;
  if (!transa && transb) {
......
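The added guard makes the no-epilogue contract explicit: this batched path must be reached with null bias and pre-GELU data pointers. As a hedged caller-side illustration (hypothetical snippet; empty torch tensors carry a null data pointer, which is what the Python wrapper's bias/gelu defaults produce):

import torch

bias = torch.Tensor()          # empty tensor -> data_ptr() == 0, so the guard passes
pre_gelu_out = torch.Tensor()  # likewise for the pre-GELU output slot
assert bias.data_ptr() == 0 and pre_gelu_out.data_ptr() == 0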
@@ -18,7 +18,7 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
__all__ = [
    "general_gemm",
    "general_grouped_gemm",
-    "general_batched_gemm",
+    "batchgemm",
]
@@ -226,84 +226,78 @@ def general_grouped_gemm(
    return out, bias, gelu_input

-def general_batched_gemm(
-    A: List[torch.Tensor],
-    B: List[torch.Tensor],
-    out: List[torch.Tensor],
-    out_dtype: torch.dtype,
-    workspaces: List[torch.Tensor],
-    layout: str = "TN",
-    m_splits: Optional[List[int]] = None,
-    gelu: bool = False,
-    grad=False,
-    accumulate: bool = False,
-    bias: Optional[List[torch.Tensor]] = None,
-    use_bias: bool = False,
-    use_split_accumulator: bool = False,
-    D_dtype: Optional[tex.DType] = None,
-    single_output=False,
-) -> Tuple[List[torch.Tensor], ...]:
-    """
-    TN layout Grouped GEMM with fp8 inputs.
-    """
-    num_gemms = len(A)
-    transa = layout[0] == "T"
-    transb = layout[1] == "T"
-    if isinstance(A[0], Float8TensorBase):
-        for a, b in zip(A, B):
-            assert_dim_for_fp8_exec(a._data)
-            assert_dim_for_fp8_exec(b._data)
-    empty_tensor = _empty_tensor()
-    empty_tensors = [empty_tensor] * num_gemms
-    # Use bfloat16 as default bias_dtype
-    gelu_input = empty_tensors
-    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
-    sm_count = get_sm_count()
-    if grad and use_bias:
-        grad_bias = [
-            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
-        ]
-    else:
-        grad_bias = empty_tensors
-    bias = bias if use_bias else empty_tensors
-    if use_bias:
-        bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
-    else:
-        bias_dtype = TE_DType[torch.bfloat16]
-    if gelu:
-        gelu_input = [
-            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
-            for o in out
-        ]  # this should differ with respect to single output
-    bias = tex.te_general_batched_gemm(
-        A,
-        transa,
-        B,
-        transb,
-        out,
-        out_dtype,
-        m_splits,
-        grad_bias if grad else bias,
-        bias_dtype,
-        single_output,
-        gelu_input,  # this is pre_gelu_out
-        grad,  # grad
-        workspaces,
-        workspaces[0].shape[0],
-        accumulate,
-        use_split_accumulator,
-        sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
-    )
-    return out, bias, gelu_input
+def batchgemm(
+    A: List[torch.Tensor],
+    B: List[torch.Tensor],
+    out: List[torch.Tensor],
+    dtype: torch.dtype,
+    workspaces: List[torch.Tensor],
+    gelu: bool = False,
+    gelu_input: Optional[List[torch.Tensor]] = None,
+    grad: bool = False,
+    accumulate: bool = False,
+    layout: str = "TN",
+    bias: Optional[List[torch.Tensor]] = None,
+    use_bias: bool = False,
+) -> Tuple[Union[torch.Tensor, None], ...]:
+    """Non FP8 batch GEMM."""
+    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
+    transa = layout[0] == "T"
+    transb = layout[1] == "T"
+    num_gemms = len(A)
+    empty_tensor = torch.Tensor()
+    empty_tensors = [torch.Tensor()] * num_gemms
+    if gelu and not grad:
+        gelu_input = [
+            torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format) for o in out
+        ]
+    elif not gelu:
+        gelu_input = empty_tensors
+    # assert [a.is_contiguous() for a in A]
+    # assert [b.is_contiguous() for b in B]
+    if grad and use_bias:
+        grad_bias = [
+            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
+        ]
+    else:
+        grad_bias = empty_tensors
+    bias = bias if use_bias else empty_tensors
+    assert (
+        A[0].dtype == dtype and B[0].dtype == dtype
+    ), f"Expected dtype={dtype}, but found A.dtype={A[0].dtype} and B.dtype={B[0].dtype}"
+    input_dtype = TE_DType[dtype]
+    output_dtype = TE_DType[out[0].dtype]
+    if use_bias:
+        bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
+    else:
+        bias_dtype = output_dtype
+    tex.te_batchgemm_ts(
+        A,
+        empty_tensor,
+        0,  # A_offset
+        input_dtype,
+        transa,
+        B,
+        empty_tensor,
+        0,  # B_offset
+        input_dtype,
+        transb,
+        out,
+        0,  # out_offset
+        empty_tensor,  # out_scale
+        output_dtype,
+        empty_tensor,  # out_amax
+        grad_bias if grad else bias,
+        bias_dtype,
+        gelu_input,  # gelu_input
+        grad,
+        workspaces,
+        workspaces[0].shape[0],
+        accumulate,
+        False,  # use_split_accumulator
+    )
+    return out, grad_bias, gelu_input
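A quick hedged sketch of calling the renamed wrapper directly on this branch (non-FP8 path; the workspace helper is the one the test file above imports, and with the default "TN" layout each GEMM computes out[i] = B[i] @ A[i].T, which is how the module's fprop uses it):

import torch
from transformer_engine.pytorch.cpp_extensions import batchgemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace

num_gemms, m, k, n = 4, 128, 256, 512
A = [torch.randn(n, k, device="cuda", dtype=torch.float16) for _ in range(num_gemms)]  # "weights"
B = [torch.randn(m, k, device="cuda", dtype=torch.float16) for _ in range(num_gemms)]  # "inputs"
out = [torch.empty(m, n, device="cuda", dtype=torch.float16) for _ in range(num_gemms)]
# dtype must match A/B per the assert above; bias/gelu stay at their defaults.
out, _, _ = batchgemm(A, B, out, torch.float16, get_multi_stream_cublas_workspace())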
@@ -102,13 +102,23 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
    bool use_split_accumulator, int math_sm_count);

#ifdef __HIP_PLATFORM_AMD__
-std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
-    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
-    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
-    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
-    transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
-    bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
-    bool use_split_accumulator, int math_sm_count);
+void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
+                  transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
+                  at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
+                  bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
+                  transformer_engine::DType D_type, at::Tensor D_amax,
+                  std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
+                  std::vector<at::Tensor> pre_gelu_out, bool grad,
+                  std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
+                  bool use_split_accumulator, int math_sm_count);
+
+std::vector<at::Tensor> te_batchgemm_ts(
+    std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
+    int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
+    int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset, at::Tensor D_scale,
+    int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias, int64_t bias_type,
+    std::vector<at::Tensor> pre_gelu_out, int64_t grad, std::vector<at::Tensor> workspace,
+    int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator);
#endif

/***************************************************************************************************
......
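The _ts suffix follows the TorchScript-style convention used elsewhere in these bindings: every enum and bool parameter is declared as int64_t and re-cast inside (see reverse_map_dtype in the next hunk). On the Python side that means TE dtype enums and plain bools can be passed as-is; a small hedged illustration:

import torch
from transformer_engine.pytorch.constants import TE_DType

# TE_DType maps torch dtypes to transformer_engine_torch.DType enum values;
# pybind11 coerces the enum and the bools to the int64_t parameters declared above.
input_dtype = TE_DType[torch.float16]
transa, transb = True, False  # arrive as int64 1 / 0 inside te_batchgemm_ts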
@@ -424,123 +424,104 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
}

#ifdef USE_ROCM
-std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
-    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
-    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
-    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
-    transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
-    bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
-    bool use_split_accumulator, int math_sm_count) {
+void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
+                  transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
+                  at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
+                  bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
+                  transformer_engine::DType D_type, at::Tensor D_amax,
+                  std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
+                  std::vector<at::Tensor> pre_gelu_out, bool grad,
+                  std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
+                  bool use_split_accumulator, int math_sm_count) {
  using namespace transformer_engine;
  using namespace transformer_engine::pytorch;
-  std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
-      te_pre_gelu_out_vector, te_workspace_vector;
-  std::vector<TensorWrapper> wrappers;
-  std::vector<at::Tensor> D_vectors;
-
-  auto none = py::none();
-
-  std::vector<size_t> single_output_begins;
-  std::vector<size_t> single_output_ends;
-  int slicing_dim;
-  if (single_output && D == std::nullopt) {
-    NVTE_ERROR("not implemented, D should be allocated for single output case.");
-  }
-  void* output_data_ptr;
-  if (single_output) {
-    output_data_ptr = (*D)[0].data_ptr();
-  }
+  std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
+  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
+  auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
+                                        transformer_engine::DType dtype, void* amax_dptr,
+                                        void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
+    tensor_wrappers.emplace_back(
+        makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
+    return tensor_wrappers.back().data();
+  };
  for (size_t i = 0; i < A.size(); i++) {
-    auto te_A = makeTransformerEngineTensor(A[i], none);
-    auto te_B = makeTransformerEngineTensor(B[i], none);
-
-    // if there is single output
-    at::Tensor out_tensor;
-    auto size_t_shape =
-        pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb);
-    bool D_numel_is_zero = false;
-    std::vector<int64_t> D_shape;
-    for (size_t t : size_t_shape) {
-      D_shape.push_back(t);
-      if (t == 0) {
-        D_numel_is_zero = true;
-      }
-    }
-    auto dtype = GetATenDType(D_type);
-    auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
-    if (single_output) {
-      if (output_data_ptr == nullptr) {
-        out_tensor = at::empty(D_shape, opts);
-      } else {
-        // We need to check !D_numel_is_zero because if the final input portion has zero elements,
-        // output_data_ptr would point beyond the allocated memory of D. This would cause
-        // at::from_blob to fail as it would reference memory not allocated by CUDA.
-        if (!D_numel_is_zero) {
-          out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
-        }
-      }
-      char* char_ptr = reinterpret_cast<char*>(output_data_ptr);
-      char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size();
-      output_data_ptr = reinterpret_cast<void*>(char_ptr);
-      D_vectors.emplace_back(out_tensor);
-    } else {
-      if (D == std::nullopt) {
-        auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
-        out_tensor = at::empty(D_shape, opts);
-        D_vectors.emplace_back(out_tensor);
-      } else {
-        out_tensor = (*D)[i];
-      }
-    }
-    if (te_A.numel() == 0 || te_B.numel() == 0) {
-      if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_();
-      if (bias[i].numel() != 0 && grad) {
-        bias[i].zero_();
-      }
-      if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_();
+    if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) {
+      if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_();
+      if (bias[i].data_ptr() != nullptr) bias[i].zero_();
+      if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_();
      continue;
    }
-
-    auto te_D = makeTransformerEngineTensor(out_tensor);
-    auto te_bias = makeTransformerEngineTensor(bias[i]);
-    auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
+    te_A.emplace_back(make_tensor(
+        A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
+        A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i)));
+    te_B.emplace_back(make_tensor(
+        B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
+        B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
+    te_D.emplace_back(make_tensor(
+        D[i].data_ptr(), {static_cast<size_t>(D[i].size(0)), static_cast<size_t>(D[i].size(1))},
+        D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
+    te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
+                                     bias_type, nullptr, nullptr, nullptr));
    const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
-                                ? std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0))}
-                                : std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0)),
-                                                      static_cast<size_t>(te_pre_gelu_out.size(1))};
-    DType gelu_type = bias_type;
-    te_pre_gelu_out =
-        makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type);
-
-    te_A_vector.emplace_back(te_A.data());
-    te_B_vector.emplace_back(te_B.data());
-    te_D_vector.emplace_back(te_D.data());
-    te_bias_vector.emplace_back(te_bias.data());
-    te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
-    wrappers.emplace_back(std::move(te_A));
-    wrappers.emplace_back(std::move(te_B));
-    wrappers.emplace_back(std::move(te_D));
-    wrappers.emplace_back(std::move(te_bias));
-    wrappers.emplace_back(std::move(te_pre_gelu_out));
+                                ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
+                                : std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
+                                                      static_cast<size_t>(pre_gelu_out[i].size(1))};
+    te_pre_gelu_out.emplace_back(make_tensor(
+        pre_gelu_out[i].data_ptr(), gelu_shape,
+        GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
  }
  for (size_t i = 0; i < workspace.size(); i++) {
-    auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte);
-    te_workspace_vector.emplace_back(wsp.data());
-    wrappers.emplace_back(std::move(wsp));
+    te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
+                                          nullptr, nullptr, nullptr));
  }
-  nvte_multi_stream_cublas_batchgemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
-                                     te_bias_vector.data(), te_pre_gelu_out_vector.data(),
-                                     te_A_vector.size(), transa, transb, grad,
-                                     te_workspace_vector.data(), accumulate, use_split_accumulator,
-                                     math_sm_count, at::cuda::getCurrentCUDAStream());
-  return bias;
+  // For now, we only have multi-stream cublas backend.
+  nvte_multi_stream_cublas_batchgemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
+                                     te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
                                     te_workspace.data(), accumulate, use_split_accumulator,
+                                     math_sm_count, at::cuda::getCurrentCUDAStream());
+}
+
+transformer_engine::DType reverse_map_dtype(int64_t dtype) {
+  if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) {
+    return static_cast<transformer_engine::DType>(dtype);
+  } else {
+    NVTE_ERROR("Type not supported.");
+  }
+}
+
+std::vector<at::Tensor> te_batchgemm_ts(
+    std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
+    int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
+    int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset, at::Tensor D_scale,
+    int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias, int64_t bias_type,
+    std::vector<at::Tensor> pre_gelu_out, int64_t grad, std::vector<at::Tensor> workspace,
+    int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator) {
+  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
+  // cast inputs to types accepted by te_gemm
+  transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
+  bool transa_arg = static_cast<bool>(transa);
+  transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
+  bool transb_arg = static_cast<bool>(transb);
+  transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
+  transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
+  bool grad_arg = static_cast<bool>(grad);
+  size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
+  bool accumulate_arg = static_cast<bool>(accumulate);
+  bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
+
+  // Set an external SM Margin to all the GEMMs.
+  // This comes in handy when DP is overlapped with GEMMs
+  const int sm_count = transformer_engine::cuda::sm_count();
+  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
+
+  te_batchgemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse,
+               B_offset, B_type_arg, transb_arg, D, D_offset, D_scale, D_type_arg, D_amax, bias,
+               bias_type_arg, pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
+               accumulate_arg, use_split_accumulator_arg, num_math_sms);
+  return D;
+}
#endif
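One detail worth noting in te_batchgemm_ts: the math-SM budget is derived from NVTE_EXT_MARGIN_SM with the device SM count as the default margin, so the env-unset case yields 0 (which the GEMM backend is expected to treat as "no explicit cap"). A worked example with illustrative numbers:

import os

sm_count = 104                                    # device SMs/CUs (illustrative)
margin = int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count)))
num_math_sms = sm_count - margin                  # unset -> 104 - 104 = 0 (no cap)
# NVTE_EXT_MARGIN_SM=8 would instead leave 96 SMs for these GEMMs.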
@@ -175,7 +175,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
#ifdef USE_ROCM
-  m.def("te_general_batched_gemm", &te_general_batched_gemm, "Batched GEMM");  /// rocblas
+  m.def("te_batchgemm_ts", &te_batchgemm_ts, "Batched GEMM");  /// rocblas
#endif
  m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
        py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
......
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""BatchedLinear API""" """Linear API"""
from typing import Union, Optional, Callable, Tuple, List
import os import os
import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
...@@ -16,7 +18,7 @@ from .base import ( ...@@ -16,7 +18,7 @@ from .base import (
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
_2X_ACC_WGRAD, _2X_ACC_WGRAD,
) )
from ..fp8 import FP8GlobalStateManager from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import ( from ..utils import (
divide, divide,
cast_if_needed, cast_if_needed,
...@@ -32,27 +34,42 @@ from ..distributed import ( ...@@ -32,27 +34,42 @@ from ..distributed import (
in_fp8_activation_recompute_phase, in_fp8_activation_recompute_phase,
) )
from ..cpp_extensions import ( from ..cpp_extensions import (
general_batched_gemm, batchgemm,
) )
from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
from ..tensor.float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
from ..cpu_offload import is_cpu_offload_enabled
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
from ..tensor.quantized_tensor import ( _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
QuantizedTensor, # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
Quantizer, _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
prepare_for_saving, log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
restore_from_saved, log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
format="[%(levelname)-8s | %(name)-19s]: %(message)s",
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
) )
__all__ = ["BatchedLinear"] __all__ = ["BatchedLinear"]
class _BatchedLinear(torch.autograd.Function): """
"""BatchedLinear semi-top level module The offset for fp8_meta_index.
_GEMM_INPUT = 0
_GEMM_WEIGHT = num_gemms
_GEMM_OUTPUT = 2 * num_gemms
Must be properly set in BatchedLinear's initialization.
"""
_GEMM_INPUT = 0
_GEMM_WEIGHT = 0
_GEMM_OUTPUT = 0
_GRAD_OUTPUT = 0
class _BatchLinear(torch.autograd.Function):
"""BatchLinear semi-top level module
Calls custom cuda extensions. Calls custom cuda extensions.
""" """
...@@ -65,205 +82,137 @@ class _BatchedLinear(torch.autograd.Function): ...@@ -65,205 +82,137 @@ class _BatchedLinear(torch.autograd.Function):
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
input_quantizers: List[Quantizer], fp8_meta: Dict[str, Any],
weight_quantizers: List[Quantizer],
output_quantizers: List[Quantizer],
grad_output_quantizers: List[Quantizer],
fuse_wgrad_accumulation: bool, fuse_wgrad_accumulation: bool,
cpu_offloading: bool, cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool, sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype, activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool, is_grad_enabled: bool,
module, *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
skip_fp8_weight_update,
*weights_and_biases,
) -> torch.Tensor: ) -> torch.Tensor:
batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2")) batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
logger = logging.getLogger("BatchLinear")
# pylint: disable=missing-function-docstring
num_gemms = len(m_splits) num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms] weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:] weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]
device = inp.device biases = weights_and_biases[2 * num_gemms :]
# TODO Support MXFP8 # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
raise NotImplementedError("BatchedLinear does not yet support MXFP8")
# TODO Support Float8 Current Scaling # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
raise NotImplementedError("BatchedLinear does not yet support Float8 Current Scaling")
# TODO Support Float8 Delayed Scaling # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().delayed():
raise NotImplementedError("BatchedLinear does not yet support Float8 Delayed Scaling")
# TODO Support Float8 Per Tensor Scaling # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling():
raise NotImplementedError("BatchedLinear does not yet support Float8 Per Tensor Scaling")
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = weights[0].shape[-1] in_features = weights[0].shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible" assert inp.shape[-1] == in_features, "GEMM not possible"
inputmats = torch.split(inp.view(-1, in_features), m_splits) inputmats = torch.split(inp.view(-1, in_features), m_splits)
if fp8: if fp8:
assert_dim_for_fp8_exec(*inputmats, *weights) assert False, "BatchLinear does not support fp8 yet."
# Cast input to expected dtype # Cast input to expected dtype
inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats] inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
inputmats = [] inputmats = []
inputmats_t = []
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
inputmats = inputmats_no_fp8
weight_requires_grad = weights[0].requires_grad logger.debug("Running forward in %s", activation_dtype)
if input_quantizers[0] is not None:
for input_quantizer in input_quantizers:
input_quantizer.set_usage(
rowwise=True,
columnwise=(is_grad_enabled and weight_requires_grad),
)
columnwise_usage = is_grad_enabled and inp.requires_grad
if not columnwise_usage:
columnwise_usage = (
is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
)
if weight_quantizers[0] is not None:
for weight_quantizer in weight_quantizers:
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
if output_quantizers[0] is not None:
for output_quantizer in output_quantizers:
output_quantizer.set_usage(rowwise=True, columnwise=False)
if fp8:
inputmats = tex.fused_multi_quantize(
inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
)
weights_fp8 = []
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
if not isinstance(weights[0], QuantizedTensor):
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
for i in range(num_gemms):
weight_fp8 = module.get_weight_workspace(
tensor=weights[i],
quantizer=weight_quantizers[i],
cache_name=(None if is_first_microbatch is None else f"weight{i}"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
)
weights_fp8.append(weight_fp8)
else:
weights_fp8 = weights
else:
inputmats = inputmats_no_fp8
bias_dtype = activation_dtype
weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights]
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases # Cast for native AMP
weights = [cast_if_needed(w, activation_dtype) for w in weights]
biases = (
[cast_if_needed(bias, activation_dtype) for bias in biases] if use_bias else biases
)
assert weights[0].size(0) % batch_num == 0, "weights[0].size(0) must be a multiple of batch_num."
assert weights_fp8[0].size(0) % batch_num == 0, "weights_fp8[0].size(0) must be a multiple of batch_num."
out = torch.empty( out = torch.empty(
[sum(m_splits), weights_fp8[0].size(0) // batch_num], [sum(m_splits), int(weights[0].size(0) // batch_num)],
dtype=activation_dtype, dtype=activation_dtype,
device=device, device=inputmats[0].device,
) )
_ = batchgemm(
_ = general_batched_gemm( weights,
weights_fp8,
inputmats, inputmats,
[out], torch.split(out, m_splits),
activation_dtype, activation_dtype,
get_multi_stream_cublas_batchgemm_workspace(), get_multi_stream_cublas_batchgemm_workspace(),
single_output=True,
m_splits=m_splits,
bias=biases, bias=biases,
use_bias=use_bias, use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
) )
if fp8_calibration:
for i in range(num_gemms):
# amax of input
for i in range(num_gemms):
input_quantizers[i].calibrate(inputmats[i])
for i in range(num_gemms):
weight_quantizers[i].calibrate(weights[i])
if is_grad_enabled: if is_grad_enabled:
saved_inputmats = [None] * num_gemms
ctx.weights_shape_1 = weights[0].shape[1] saved_inputmats_t = [None] * num_gemms
if weights[0].requires_grad:
tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases) saved_inputmats = inputmats_no_fp8
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects if cpu_offloading:
if fuse_wgrad_accumulation:
ctx.weights_requires_grad = weights[0].requires_grad for w in weights:
if fuse_wgrad_accumulation and ctx.weights_requires_grad: w.main_grad.weight_offloading = True
ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)] for w in weights:
else: w.weight_offloading = True
ctx.main_grads = [None] * num_gemms for t in saved_inputmats:
ctx.device = device if t is not None:
ctx.grad_output_quantizers = grad_output_quantizers t.activation_offloading = True
ctx.save_for_backward(
None,
*saved_inputmats,
*saved_inputmats_t,
*weights,
*weights_fp8,
*[
w.main_grad if cpu_offloading and fuse_wgrad_accumulation else None
for w in weights
],
)
ctx.m_splits = m_splits ctx.m_splits = m_splits
ctx.num_gemms = num_gemms ctx.num_gemms = num_gemms
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8 ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
ctx.reduce_and_update_bwd_fp8_tensors = (
ctx.reduce_and_update_bwd_fp8_tensors
or FP8GlobalStateManager.is_first_fp8_module()
)
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1]) return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring logger = logging.getLogger("BatchLinear")
with torch.cuda.nvtx.range("_BatchedLinear_backward"):
saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) with torch.cuda.nvtx.range("_BatchLinear_backward"):
N = ctx.num_gemms (
inputmats = saved_tensors[:N] fwd_scale_inverses,
weights = saved_tensors[N : 2 * N] *saved_tensors,
biases = saved_tensors[2 * N : 3 * N] ) = ctx.saved_tensors
main_grads = ctx.main_grads inputmats = saved_tensors[: ctx.num_gemms]
inputmats_t = saved_tensors[ctx.num_gemms : 2 * ctx.num_gemms]
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TODO weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms]
weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms]
main_grads = saved_tensors[4 * ctx.num_gemms :]
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
for i in range(ctx.num_gemms): for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w = torch.nn.Parameter(weights[i], False)
w.main_grad = main_grads[i] w.main_grad = main_grads[i]
weights[i] = w weights[i] = w
global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
# preprocess grad_output
grad_output = grad_output.contiguous() grad_output = grad_output.contiguous()
grad_output_mats = torch.split( grad_output_mats = torch.split(
grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits
) )
grad_output = [None] * ctx.num_gemms grad_output_c = [None] * ctx.num_gemms
grad_output_t = [None] * ctx.num_gemms
grad_biases = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms
if ctx.fp8:
if ctx.use_bias:
for i in range(ctx.num_gemms):
grad_biases[i], grad_output[i] = tex.bgrad_quantize(
grad_output_mats[i], ctx.grad_output_quantizers[i]
)
else:
grad_output = tex.fused_multi_quantize(
grad_output_mats,
None,
ctx.grad_output_quantizers,
TE_DType[ctx.activation_dtype],
)
else:
grad_output = grad_output_mats
if ctx.is_first_microbatch is not None: if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = ( accumulate_wgrad_into_param_main_grad = (
...@@ -273,114 +222,105 @@ class _BatchedLinear(torch.autograd.Function): ...@@ -273,114 +222,105 @@ class _BatchedLinear(torch.autograd.Function):
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.requires_dgrad: if ctx.requires_dgrad:
logger.debug("Running backward in %s", ctx.activation_dtype)
dgrad = torch.empty( dgrad = torch.empty(
(sum(ctx.m_splits), ctx.weights_shape_1), (sum(ctx.m_splits), int(weights[0].size(1))),
dtype=ctx.activation_dtype, dtype=ctx.activation_dtype,
device=ctx.device, device=grad_output.device,
) )
batchgemm(
general_batched_gemm(
weights, weights,
grad_output, grad_output_mats,
[dgrad], torch.split(dgrad, ctx.m_splits),
ctx.activation_dtype, ctx.activation_dtype,
get_multi_stream_cublas_batchgemm_workspace(), get_multi_stream_cublas_batchgemm_workspace(),
single_output=True,
layout="NN", layout="NN",
m_splits=ctx.m_splits,
grad=True, grad=True,
use_split_accumulator=_2X_ACC_DGRAD,
) )
if ctx.weights_requires_grad: if weights[0].requires_grad:
if ctx.fuse_wgrad_accumulation: if ctx.fuse_wgrad_accumulation:
wgrad_list = main_grads wgrad_list = [w.main_grad for w in weights]
else: else:
wgrad_list = [ wgrad_list = [
torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
for w in weights for w in weights
] ]
# WGRAD # WGRAD
_, grad_biases_, _ = general_batched_gemm( _, grad_biases, _ = batchgemm(
inputmats, inputmats,
grad_output, grad_output_mats,
wgrad_list, wgrad_list,
ctx.activation_dtype, ctx.activation_dtype,
get_multi_stream_cublas_batchgemm_workspace(), get_multi_stream_cublas_batchgemm_workspace(),
layout="NT", layout="NT",
grad=True, grad=True,
m_splits=ctx.m_splits, use_bias=ctx.use_bias,
use_bias=ctx.use_bias if grad_biases[0] is None else None,
bias=biases,
use_split_accumulator=_2X_ACC_WGRAD,
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
) )
for i in range(ctx.num_gemms):
if grad_biases[i] is None:
grad_biases[i] = grad_biases_[i]
del grad_biases_
# Deallocate input tensor # Deallocate input tensor
clear_tensor_data(*inputmats) clear_tensor_data(*inputmats)
clear_tensor_data(*inputmats_t)
def handle_custom_ddp_from_mcore(w, wgrad): if not ctx.use_bias:
if ctx.weights_requires_grad: grad_biases = [None] * ctx.num_gemms
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False):
wgrad = torch.zeros(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
wgrad = torch.empty(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
else:
wgrad = None
return wgrad
wgrad_list = [ def handle_custom_ddp_from_mcore(w, wgrad):
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list) if w.requires_grad:
] if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False):
wgrad = torch.zeros(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
wgrad = torch.empty(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
else: else:
wgrad_list = [None] * ctx.num_gemms wgrad = None
return wgrad
if not ctx.use_bias: wgrad_list = [
grad_biases = [None] * ctx.num_gemms handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
]
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return ( return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
None, None, # m_splits
None, None, # use_bias
None, None, # is_first_microbatch
None, None, # fp8
None, None, # fp8_calibration
None, None, # fp8_meta
None, None, # fuse_wgrad_accumulation
None, None, # cpu_offloading
None, None, # tp_group
None, None, # tp_size
None, None, # sequence_parallel
None, None, # tensor_parallel
None, None, # activation_dtype
None, None, # parallel_mode
None, # is_grad_enabled
None, # is_grad_enabled None, # is_grad_enabled
*wgrad_list, *wgrad_list,
*([None] * ctx.num_gemms), # weights_fp8
*grad_biases, *grad_biases,
) )
class BatchedLinear(TransformerEngineBaseModule): class BatchedLinear(TransformerEngineBaseModule):
"""Applies linear transformations to the incoming data list """Applies linear transformations to the incoming data list
:math:`y_i = x_iA_i^T + b_i` in a batched way. :math:`y_i = x_iA_i^T + b_i` in a batched way.
@@ -399,14 +339,31 @@ class BatchedLinear(TransformerEngineBaseModule):
              used for initializing weights in the following way: `init_method(weight)`.
              When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
get_rng_state_tracker : Callable, default = `None`
              used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default = `None`
              the param passed to get_rng_state_tracker to get the specific rng tracker.
device : Union[torch.device, str], default = "cuda"
              The device on which the parameters of the model will be allocated. It is the user's
              responsibility to ensure all parameters are moved to the GPU before running the
              forward pass.
+
+    Parallelism parameters
+    ----------------------
+    sequence_parallel : bool, default = `False`
+              if set to `True`, uses sequence parallelism.
+    tp_group : ProcessGroup, default = `None`
+              tensor parallel process group.
+    tp_size : int, default = 1
+              used as TP (tensor parallel) world size when TP groups are not formed during
+              initialization. In this case, users must call the
+              `set_tensor_parallel_group(tp_group)` method on the initialized module before the
+              forward pass to supply the tensor parallel group needed for tensor and sequence
+              parallel collectives.
+    parallel_mode : {None, 'Column', 'Row'}, default = `None`
+              used to decide whether this BatchedLinear layer is Column Parallel Linear or Row
+              Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
+              When set to `None`, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
...@@ -426,7 +383,6 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -426,7 +383,6 @@ class BatchedLinear(TransformerEngineBaseModule):
would not fit in GPU memory. would not fit in GPU memory.
""" """
def __init__( def __init__(
self, self,
num_gemms: int, num_gemms: int,
...@@ -462,15 +418,15 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -462,15 +418,15 @@ class BatchedLinear(TransformerEngineBaseModule):
self.apply_bias = bias and not return_bias self.apply_bias = bias and not return_bias
self.ub_overlap_rs = ub_overlap_rs self.ub_overlap_rs = ub_overlap_rs
self.ub_overlap_ag = ub_overlap_ag self.ub_overlap_ag = ub_overlap_ag
if ub_overlap_rs or ub_overlap_ag:
assert ub_name is not None, "Userbuffer name [string] is not set."
self.ub_name = ub_name self.ub_name = ub_name
assert (
not ub_overlap_rs and not ub_overlap_ag
), "BatchedLinear doesn't support Userbuffer overlap."
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name self.rng_tracker_name = rng_tracker_name
self._offsets = {"input": 0, "weight": self.num_gemms, "output": 2 * self.num_gemms, "grad_output": 0} global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
_GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, self.num_gemms, 2 * self.num_gemms
if tp_group is None: if tp_group is None:
self.tp_size = tp_size self.tp_size = tp_size
if tp_size == 1: if tp_size == 1:
...@@ -492,7 +448,7 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -492,7 +448,7 @@ class BatchedLinear(TransformerEngineBaseModule):
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
# In batchgemm, we use batch=batch_num to launch blas batchgemm # In batchgemm, we use batch=batch_num to launch blas batchgemm
for i in range(self.num_gemms): for i in range(int(self.num_gemms)):
# Construct weight parameter # Construct weight parameter
self.register_parameter( self.register_parameter(
f"weight{i}", f"weight{i}",
...@@ -506,7 +462,7 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -506,7 +462,7 @@ class BatchedLinear(TransformerEngineBaseModule):
), ),
init_fn=init_method, init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker, get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=self._offsets["weight"] + i, fp8_meta_index=_GEMM_WEIGHT + i,
) )
# Construct bias parameters if needed # Construct bias parameters if needed
...@@ -515,7 +471,7 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -515,7 +471,7 @@ class BatchedLinear(TransformerEngineBaseModule):
f"bias{i}", f"bias{i}",
torch.nn.Parameter( torch.nn.Parameter(
torch.empty( torch.empty(
self.out_features, self.out_features * self.batch_num,
device=device, device=device,
dtype=params_dtype, dtype=params_dtype,
), ),
...@@ -525,11 +481,15 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -525,11 +481,15 @@ class BatchedLinear(TransformerEngineBaseModule):
else: else:
bias = torch.Tensor().to(dtype=params_dtype, device=device) bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, f"bias{i}", bias) setattr(self, f"bias{i}", bias)
if self.primary_weights_in_fp8: if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=self.num_gemms) self.init_fp8_metadata(num_gemms=self.num_gemms)
self.reset_parameters(defer_init=device == "meta") self.reset_parameters(defer_init=(device == "meta"))
# For RPL, bias has to be added after TP collectives # For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM # So it cannot be fused with the GEMM
...@@ -543,7 +503,7 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -543,7 +503,7 @@ class BatchedLinear(TransformerEngineBaseModule):
if not defer_init: if not defer_init:
# Set parallelism attributes for linear weights # Set parallelism attributes for linear weights
for i in range(self.num_gemms): for i in range(int(self.num_gemms)):
set_tensor_model_parallel_attributes( set_tensor_model_parallel_attributes(
tensor=getattr(self, f"weight{i}"), tensor=getattr(self, f"weight{i}"),
is_parallel=True, is_parallel=True,
...@@ -553,15 +513,15 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -553,15 +513,15 @@ class BatchedLinear(TransformerEngineBaseModule):
# Set parallelism attributes for linear biases # Set parallelism attributes for linear biases
if self.use_bias: if self.use_bias:
for bias in self.bias_names: for i in range(self.num_gemms):
if self.parallel_mode == "row": if self.parallel_mode == "row":
setattr( setattr(
getattr(self, bias), getattr(self, f"bias{i}"),
"sequence_parallel", "sequence_parallel",
self.sequence_parallel, self.sequence_parallel,
) )
elif self.parallel_mode == "column": elif self.parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1)
@no_torch_dynamo() @no_torch_dynamo()
def forward( def forward(
...@@ -593,57 +553,33 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -593,57 +553,33 @@ class BatchedLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
assert not isinstance( assert not isinstance(
inp, Float8Tensor inp, Float8Tensor
), "BatchedLinear doesn't support input tensor in FP8." ), "BatchedLinear doesn't support input tensor in FP8."
m_splits_batch_gemm = [x * self.batch_num for x in m_splits[0:int(self.num_gemms)]] m_splits_batch_gemm = [x * self.batch_num for x in m_splits[0:int(self.num_gemms)]]
assert len(m_splits_batch_gemm) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits_batch_gemm) == self.num_gemms, "Number of splits should match number of GEMMs."
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None: if skip_fp8_weight_update is not None:
is_first_microbatch = False is_first_microbatch = False
with self.prepare_forward(inp=inp, num_gemms=self.num_gemms) as inp:
with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: weight_tensors = [getattr(self, f"weight{i}") for i in range(int(self.num_gemms))]
bias_tensors = [getattr(self, f"bias{i}") for i in range(int(self.num_gemms))]
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8: if not self.fp8:
weight_tensors = [ weight_tensors = [
w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors w.from_float8() if isinstance(w, Float8Tensor) else w for w in weight_tensors
] ]
input_quantizers, weight_quantizers, output_quantizers = ( weight_tensors_fp8 = [None] * int(self.num_gemms)
[None] * self.num_gemms,
[None] * self.num_gemms, from ..cpu_offload import CPUOffloadEnabled
[None] * self.num_gemms,
)
grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
if self.fp8:
input_quantizers = [
self.quantizers["scaling_fwd"][self._offsets["input"] + i]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
input_quantizers[i].internal = True
weight_quantizers = [
self.quantizers["scaling_fwd"][self._offsets["weight"] + i]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
weight_quantizers[i].internal = True
if torch.is_grad_enabled():
grad_output_quantizers = [
self.quantizers["scaling_bwd"][self._offsets["input"] + i]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
grad_output_quantizers[i].internal = True
if torch.is_grad_enabled(): if torch.is_grad_enabled():
linear_fn = _BatchedLinear.apply linear_fn = _BatchLinear.apply
args = [] args = []
else: else:
linear_fn = _BatchedLinear.forward linear_fn = _BatchLinear.forward
args = [None] args = [None]
args += ( args += (
inp, inp,
...@@ -652,22 +588,22 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -652,22 +588,22 @@ class BatchedLinear(TransformerEngineBaseModule):
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
input_quantizers, self.fp8_meta,
weight_quantizers,
output_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(), CPUOffloadEnabled,
self.tp_group,
self.tp_size,
self.sequence_parallel, self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype, self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(), torch.is_grad_enabled(),
self,
skip_fp8_weight_update,
*weight_tensors, *weight_tensors,
*weight_tensors_fp8,
*bias_tensors, *bias_tensors,
) )
out = linear_fn(*args) out = linear_fn(*args)
if self.gemm_bias_unfused_add: if self.gemm_bias_unfused_add:
out_shape = out.shape out_shape = out.shape
out = torch.cat( out = torch.cat(
...@@ -678,7 +614,6 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -678,7 +614,6 @@ class BatchedLinear(TransformerEngineBaseModule):
) )
] ]
).view(out_shape) ).view(out_shape)
if self.return_bias: if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
return out return out
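Finally, the long None-padded tuple returned by _BatchLinear.backward above is just the torch.autograd.Function contract: backward must return one gradient slot per forward argument, in order. A self-contained sketch of the pattern (generic PyTorch, not this module's API):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, factor, *weights):  # two plain args + N trailing tensors
        ctx.factor = factor
        ctx.n = len(weights)
        return inp * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One slot per forward input: grad for inp, None for the non-tensor
        # factor, then one slot per trailing weight -- the same bookkeeping
        # _BatchLinear.backward does with its labelled Nones.
        return (grad_out * ctx.factor, None) + (None,) * ctx.n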