Commit 9d0f1c9b authored by yuguo

[DCU] add batchgemm test

parent e8f92b93
......@@ -26,6 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
......
......@@ -4,7 +4,7 @@
"""Installation script."""
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
# VTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
import os
import sys
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from collections import OrderedDict
import math
import os
from typing import Dict, List, Tuple, Optional
import pytest
import copy
import random
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
fp8_autocast,
fp8_model_init,
)
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
attention_mask_func,
is_bf16_compatible,
)
from transformer_engine.pytorch import (
DotProductAttention,
LayerNormLinear,
LayerNormMLP,
Linear,
GroupedLinear,
BatchedLinear,
MultiheadAttention,
RMSNorm,
TransformerLayer,
LayerNorm,
Fp8Padding,
Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm, batchgemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
sm_80plus = get_device_compute_capability() >= (8, 0)
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
if torch_version() >= (2, 7, 0):
torch._dynamo.config.recompile_limit = 16
else:
torch._dynamo.config.cache_size_limit = 16
class ModelConfig:
def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
self.hidden_size = hidden_size
self.eps = eps
self.num_attention_heads = num_attention_heads
self.embed = embed
self.num_layers = num_layers
self.seq_len = seq_len
model_configs = {
"small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}
model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
batch_sizes = [1, 2]
all_boolean = [True, False]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
mask_types = ["causal", "no_mask"]
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
]
def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
raise ValueError(f"Unsuppored dtype ({dtype})")
def assert_allclose(
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)):
tols = dict(atol=atol)
if rtol is not None:
tols["rtol"] = rtol
result = torch.allclose(t1, t2, **tols)
if not result:
diff = torch.abs(t1 - t2)
tol = atol + ((rtol if rtol is not None else 0.0) * torch.abs(t2))
exceed_mask = diff > tol
if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
max_diff = diff[exceed_mask].max()
max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
max_location = [idx[max_idx].item() for idx in indices]
msg = (
f"Outputs not close enough in tensor at idx={i}. "
f"Maximum difference at location {max_location} "
f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
f"(diff {max_diff.item()})."
)
raise AssertionError(msg)
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
def _test_batched_linear_accuracy(
block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
):
reset_rng_states()
if fp8:
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
assert config.seq_len % num_gemms == 0
m_splits = torch.tensor([config.seq_len // num_gemms for i in range(num_gemms)])
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if isinstance(block, BatchedLinear):
m_splits = m_splits * bs
out = block(inp_hidden_states, m_splits.tolist())
else:
out = torch.cat(
[
block[i](inp)
for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
]
)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [4, 8])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_batched_linear_accuracy(
dtype,
num_gemms,
bs,
model,
fp8,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
parallel_mode=None,
):
batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches
pytest.skip("MXFP8 unsupported for batched linear.")
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for batched linear.")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
batched_linear = BatchedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
sequential_linear = torch.nn.ModuleList(
[
Linear(
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
for _ in range(num_gemms)
]
)
# Share params
with torch.no_grad():
for i in range(num_gemms // batch_num):
weight = getattr(batched_linear, f"weight{i}").clone()
# bias = getattr(batched_linear, f"bias{i}").clone()
if fuse_wgrad_accumulation:
weight_i = getattr(batched_linear, f"weight{i}")
weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
for j in range(batch_num):
sequential_linear[i * batch_num + j].weight = Parameter(weight[weight.shape[0] // batch_num * j : weight.shape[0] // batch_num * (j + 1)].clone())
# sequential_linear[i * batch_num + j].bias = Parameter(bias[bias.shape[0] // batch_num * j : bias.shape[0] // batch_num * (j + 1)].clone())
if fuse_wgrad_accumulation:
sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[weight_i.main_grad.shape[0] // batch_num * j : weight_i.main_grad.shape[0] // batch_num * (j + 1)].clone()
outputs_ref = _test_batched_linear_accuracy(
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
)
outputs = _test_batched_linear_accuracy(
batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
)
# Outputs should match within the given tolerances
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)
if __name__ == "__main__":
test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True)
......@@ -778,6 +778,9 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
NVTE_ERROR("MOE batchgemm not surpport bias or gelu.");
}
int m, n, k;
if (!transa && transb) {
......
......@@ -18,7 +18,7 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
__all__ = [
"general_gemm",
"general_grouped_gemm",
"general_batched_gemm",
"batchgemm",
]
......@@ -226,84 +226,78 @@ def general_grouped_gemm(
return out, bias, gelu_input
def general_batched_gemm(
def batchgemm(
A: List[torch.Tensor],
B: List[torch.Tensor],
out: List[torch.Tensor],
out_dtype: torch.dtype,
dtype: torch.dtype,
workspaces: List[torch.Tensor],
layout: str = "TN",
m_splits: Optional[List[int]] = None,
gelu: bool = False,
grad=False,
gelu_input: Optional[List[torch.Tensor]] = None,
grad: bool = False,
accumulate: bool = False,
layout: str = "TN",
bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
single_output=False,
) -> Tuple[List[torch.Tensor], ...]:
"""
TN layout Grouped GEMM with fp8 inputs.
"""
num_gemms = len(A)
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Non FP8 batch GEMM."""
assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
transa = layout[0] == "T"
transb = layout[1] == "T"
num_gemms = len(A)
empty_tensor = torch.Tensor()
empty_tensors = [torch.Tensor()] * num_gemms
# assert [a.is_contiguous() for a in A]
# assert [b.is_contiguous() for b in B]
if isinstance(A[0], Float8TensorBase):
for a, b in zip(A, B):
assert_dim_for_fp8_exec(a._data)
assert_dim_for_fp8_exec(b._data)
empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms
# Use bfloat16 as default bias_dtype
if gelu and not grad:
gelu_input = [
torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format) for o in out
]
elif not gelu:
gelu_input = empty_tensors
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
sm_count = get_sm_count()
if grad and use_bias:
grad_bias = [
torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
]
else:
grad_bias = empty_tensors
bias = bias if use_bias else empty_tensors
assert (
A[0].dtype == dtype and B[0].dtype == dtype
), f"Expected dtype={dtype}, but found A.dtype={A[0].dtype} and B.dtype={B[0].dtype}"
input_dtype = TE_DType[dtype]
output_dtype = TE_DType[out[0].dtype]
if use_bias:
bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
else:
bias_dtype = TE_DType[torch.bfloat16]
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
] # this should differ with respect to single output
bias = tex.te_general_batched_gemm(
bias_dtype = output_dtype
tex.te_batchgemm_ts(
A,
empty_tensor,
0, # A_offset
input_dtype,
transa,
B,
empty_tensor,
0, # B_offset
input_dtype,
transb,
out,
out_dtype,
m_splits,
0, # out_offset
empty_tensor, # out_scale
output_dtype,
empty_tensor, # out_amax
grad_bias if grad else bias,
bias_dtype,
single_output,
gelu_input, # this is pre_gelu_out
grad, # grad
gelu_input, # gelu_input
grad,
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
False, # use_split_accumulator
)
return out, bias, gelu_input
return out, grad_bias, gelu_input
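For orientation, a minimal hypothetical sketch of how the new `batchgemm` wrapper could be driven for a plain bf16 batched GEMM in the default "TN" layout; the shapes, the two-GEMM batch, and the use of `get_multi_stream_cublas_workspace` as the workspace source are illustrative assumptions rather than part of this change.

# Hypothetical usage sketch (not part of this diff): non-FP8 batched GEMM.
# With the default "TN" layout each result is out[i] = B[i] @ A[i].T.
import torch
from transformer_engine.pytorch.cpp_extensions import batchgemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace

num_gemms, m, n, k = 2, 512, 256, 128
A = [torch.randn(n, k, dtype=torch.bfloat16, device="cuda") for _ in range(num_gemms)]  # weight-like operands
B = [torch.randn(m, k, dtype=torch.bfloat16, device="cuda") for _ in range(num_gemms)]  # input-like operands
out = [torch.empty(m, n, dtype=torch.bfloat16, device="cuda") for _ in range(num_gemms)]
out, _, _ = batchgemm(A, B, out, torch.bfloat16, get_multi_stream_cublas_workspace())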
......@@ -102,13 +102,23 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
bool use_split_accumulator, int math_sm_count);
#ifdef __HIP_PLATFORM_AMD__
std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax,
std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);
std::vector<at::Tensor> te_batchgemm_ts(
std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset, at::Tensor D_scale,
int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias, int64_t bias_type,
std::vector<at::Tensor> pre_gelu_out, int64_t grad, std::vector<at::Tensor> workspace,
int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator);
#endif
/***************************************************************************************************
......
......@@ -424,123 +424,104 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
}
#ifdef USE_ROCM
std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax,
std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
te_pre_gelu_out_vector, te_workspace_vector;
std::vector<TensorWrapper> wrappers;
std::vector<at::Tensor> D_vectors;
auto none = py::none();
std::vector<size_t> single_output_begins;
std::vector<size_t> single_output_ends;
int slicing_dim;
if (single_output && D == std::nullopt) {
NVTE_ERROR("not implemented, D should be allocated for single output case.");
std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype, void* amax_dptr,
void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
tensor_wrappers.emplace_back(
makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
for (size_t i = 0; i < A.size(); i++) {
if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) {
if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_();
if (bias[i].data_ptr() != nullptr) bias[i].zero_();
if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_();
continue;
}
te_A.emplace_back(make_tensor(
A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i)));
te_B.emplace_back(make_tensor(
B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
te_D.emplace_back(make_tensor(
D[i].data_ptr(), {static_cast<size_t>(D[i].size(0)), static_cast<size_t>(D[i].size(1))},
D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
bias_type, nullptr, nullptr, nullptr));
void* output_data_ptr;
if (single_output) {
output_data_ptr = (*D)[0].data_ptr();
const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
static_cast<size_t>(pre_gelu_out[i].size(1))};
te_pre_gelu_out.emplace_back(make_tensor(
pre_gelu_out[i].data_ptr(), gelu_shape,
GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
}
for (size_t i = 0; i < workspace.size(); i++) {
te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
nullptr, nullptr, nullptr));
}
for (size_t i = 0; i < A.size(); i++) {
auto te_A = makeTransformerEngineTensor(A[i], none);
auto te_B = makeTransformerEngineTensor(B[i], none);
nvte_multi_stream_cublas_batchgemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}
// if there is single output
at::Tensor out_tensor;
auto size_t_shape =
pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb);
bool D_numel_is_zero = false;
std::vector<int64_t> D_shape;
for (size_t t : size_t_shape) {
D_shape.push_back(t);
if (t == 0) {
D_numel_is_zero = true;
}
}
auto dtype = GetATenDType(D_type);
auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
if (single_output) {
if (output_data_ptr == nullptr) {
out_tensor = at::empty(D_shape, opts);
} else {
// We need to check !D_numel_is_zero because if the final input portion has zero elements,
// output_data_ptr would point beyond the allocated memory of D. This would cause
// at::from_blob to fail as it would reference memory not allocated by CUDA.
if (!D_numel_is_zero) {
out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
}
}
char* char_ptr = reinterpret_cast<char*>(output_data_ptr);
char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size();
output_data_ptr = reinterpret_cast<void*>(char_ptr);
D_vectors.emplace_back(out_tensor);
} else {
if (D == std::nullopt) {
auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
out_tensor = at::empty(D_shape, opts);
D_vectors.emplace_back(out_tensor);
transformer_engine::DType reverse_map_dtype(int64_t dtype) {
if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) {
return static_cast<transformer_engine::DType>(dtype);
} else {
out_tensor = (*D)[i];
}
}
if (te_A.numel() == 0 || te_B.numel() == 0) {
if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_();
if (bias[i].numel() != 0 && grad) {
bias[i].zero_();
NVTE_ERROR("Type not supported.");
}
if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_();
continue;
}
auto te_D = makeTransformerEngineTensor(out_tensor);
auto te_bias = makeTransformerEngineTensor(bias[i]);
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
}
const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0))}
: std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0)),
static_cast<size_t>(te_pre_gelu_out.size(1))};
std::vector<at::Tensor> te_batchgemm_ts(
std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset, at::Tensor D_scale,
int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias, int64_t bias_type,
std::vector<at::Tensor> pre_gelu_out, int64_t grad, std::vector<at::Tensor> workspace,
int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
// cast inputs to types accepted by te_gemm
transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
bool transa_arg = static_cast<bool>(transa);
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
DType gelu_type = bias_type;
te_pre_gelu_out =
makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type);
// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs
te_A_vector.emplace_back(te_A.data());
te_B_vector.emplace_back(te_B.data());
te_D_vector.emplace_back(te_D.data());
te_bias_vector.emplace_back(te_bias.data());
te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
const int sm_count = transformer_engine::cuda::sm_count();
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
wrappers.emplace_back(std::move(te_A));
wrappers.emplace_back(std::move(te_B));
wrappers.emplace_back(std::move(te_D));
wrappers.emplace_back(std::move(te_bias));
wrappers.emplace_back(std::move(te_pre_gelu_out));
}
for (size_t i = 0; i < workspace.size(); i++) {
auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte);
te_workspace_vector.emplace_back(wsp.data());
wrappers.emplace_back(std::move(wsp));
}
// For now, we only have multi-stream cublas backend.
nvte_multi_stream_cublas_batchgemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
te_bias_vector.data(), te_pre_gelu_out_vector.data(),
te_A_vector.size(), transa, transb, grad,
te_workspace_vector.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
return bias;
te_batchgemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse,
B_offset, B_type_arg, transb_arg, D, D_offset, D_scale, D_type_arg, D_amax, bias,
bias_type_arg, pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
accumulate_arg, use_split_accumulator_arg, num_math_sms);
return D;
}
#endif
......@@ -175,7 +175,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
#ifdef USE_ROCM
m.def("te_general_batched_gemm", &te_general_batched_gemm, "Batched GEMM"); /// rocblas
m.def("te_batchgemm_ts", &te_batchgemm_ts, "Batched GEMM"); /// rocblas
#endif
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
......
......@@ -2,9 +2,11 @@
#
# See LICENSE for license information.
"""BatchedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
"""Linear API"""
import os
import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
import torch
import transformer_engine_torch as tex
......@@ -16,7 +18,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import FP8GlobalStateManager
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
divide,
cast_if_needed,
......@@ -32,27 +34,42 @@ from ..distributed import (
in_fp8_activation_recompute_phase,
)
from ..cpp_extensions import (
general_batched_gemm,
batchgemm,
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..tensor.float8_tensor import Float8Tensor
from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
from ..float8_tensor import Float8Tensor
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
format="[%(levelname)-8s | %(name)-19s]: %(message)s",
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)
__all__ = ["BatchedLinear"]
class _BatchedLinear(torch.autograd.Function):
"""BatchedLinear semi-top level module
"""
The offset for fp8_meta_index.
_GEMM_INPUT = 0
_GEMM_WEIGHT = num_gemms
_GEMM_OUTPUT = 2 * num_gemms
Must be properly set in BatchedLinear's initialization.
"""
_GEMM_INPUT = 0
_GEMM_WEIGHT = 0
_GEMM_OUTPUT = 0
_GRAD_OUTPUT = 0
class _BatchLinear(torch.autograd.Function):
"""BatchLinear semi-top level module
Calls custom cuda extensions.
"""
......@@ -65,205 +82,137 @@ class _BatchedLinear(torch.autograd.Function):
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
input_quantizers: List[Quantizer],
weight_quantizers: List[Quantizer],
output_quantizers: List[Quantizer],
grad_output_quantizers: List[Quantizer],
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
module,
skip_fp8_weight_update,
*weights_and_biases,
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
) -> torch.Tensor:
batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
# pylint: disable=missing-function-docstring
logger = logging.getLogger("BatchLinear")
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:]
device = inp.device
# TODO Support MXFP8 # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
raise NotImplementedError("BatchedLinear does not yet support MXFP8")
# TODO Support Float8 Current Scaling # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
raise NotImplementedError("BatchedLinear does not yet support Float8 Current Scaling")
# TODO Support Float8 Delayed Scaling # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().delayed():
raise NotImplementedError("BatchedLinear does not yet support Float8 Delayed Scaling")
# TODO Support Float8 Per Tensor Scaling # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling():
raise NotImplementedError("BatchedLinear does not yet support Float8 Per Tensor Scaling")
weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]
biases = weights_and_biases[2 * num_gemms :]
# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmats = torch.split(inp.view(-1, in_features), m_splits)
if fp8:
assert_dim_for_fp8_exec(*inputmats, *weights)
assert False, "BatchLinear does not support fp8 yet."
# Cast input to expected dtype
inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
inputmats = []
inputmats_t = []
weight_requires_grad = weights[0].requires_grad
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
inputmats = inputmats_no_fp8
if input_quantizers[0] is not None:
for input_quantizer in input_quantizers:
input_quantizer.set_usage(
rowwise=True,
columnwise=(is_grad_enabled and weight_requires_grad),
)
columnwise_usage = is_grad_enabled and inp.requires_grad
if not columnwise_usage:
columnwise_usage = (
is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
)
if weight_quantizers[0] is not None:
for weight_quantizer in weight_quantizers:
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
if output_quantizers[0] is not None:
for output_quantizer in output_quantizers:
output_quantizer.set_usage(rowwise=True, columnwise=False)
logger.debug("Running forward in %s", activation_dtype)
if fp8:
inputmats = tex.fused_multi_quantize(
inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
)
weights_fp8 = []
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
if not isinstance(weights[0], QuantizedTensor):
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
for i in range(num_gemms):
weight_fp8 = module.get_weight_workspace(
tensor=weights[i],
quantizer=weight_quantizers[i],
cache_name=(None if is_first_microbatch is None else f"weight{i}"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
# Cast for native AMP
weights = [cast_if_needed(w, activation_dtype) for w in weights]
biases = (
[cast_if_needed(bias, activation_dtype) for bias in biases] if use_bias else biases
)
weights_fp8.append(weight_fp8)
else:
weights_fp8 = weights
assert weights[0].size(0) % batch_num == 0, "weights[0].size(0) should be a multiple of batch_num."
else:
inputmats = inputmats_no_fp8
bias_dtype = activation_dtype
weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights]
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
assert weights_fp8[0].size(0) % batch_num == 0, "weights_fp8[0].size(0) should be a multiple of batch_num."
out = torch.empty(
[sum(m_splits), weights_fp8[0].size(0) // batch_num],
[sum(m_splits), int(weights[0].size(0) // batch_num)],
dtype=activation_dtype,
device=device,
device=inputmats[0].device,
)
_ = general_batched_gemm(
weights_fp8,
_ = batchgemm(
weights,
inputmats,
[out],
torch.split(out, m_splits),
activation_dtype,
get_multi_stream_cublas_batchgemm_workspace(),
single_output=True,
m_splits=m_splits,
bias=biases,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
)
if fp8_calibration:
for i in range(num_gemms):
# amax of input
for i in range(num_gemms):
input_quantizers[i].calibrate(inputmats[i])
for i in range(num_gemms):
weight_quantizers[i].calibrate(weights[i])
if is_grad_enabled:
ctx.weights_shape_1 = weights[0].shape[1]
tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
ctx.weights_requires_grad = weights[0].requires_grad
if fuse_wgrad_accumulation and ctx.weights_requires_grad:
ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]
else:
ctx.main_grads = [None] * num_gemms
ctx.device = device
ctx.grad_output_quantizers = grad_output_quantizers
saved_inputmats = [None] * num_gemms
saved_inputmats_t = [None] * num_gemms
if weights[0].requires_grad:
saved_inputmats = inputmats_no_fp8
if cpu_offloading:
if fuse_wgrad_accumulation:
for w in weights:
w.main_grad.weight_offloading = True
for w in weights:
w.weight_offloading = True
for t in saved_inputmats:
if t is not None:
t.activation_offloading = True
ctx.save_for_backward(
None,
*saved_inputmats,
*saved_inputmats_t,
*weights,
*weights_fp8,
*[
w.main_grad if cpu_offloading and fuse_wgrad_accumulation else None
for w in weights
],
)
ctx.m_splits = m_splits
ctx.num_gemms = num_gemms
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
ctx.reduce_and_update_bwd_fp8_tensors = (
ctx.reduce_and_update_bwd_fp8_tensors
or FP8GlobalStateManager.is_first_fp8_module()
)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_BatchedLinear_backward"):
saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
N = ctx.num_gemms
inputmats = saved_tensors[:N]
weights = saved_tensors[N : 2 * N]
biases = saved_tensors[2 * N : 3 * N]
main_grads = ctx.main_grads
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:  # TODO
logger = logging.getLogger("BatchLinear")
with torch.cuda.nvtx.range("_BatchLinear_backward"):
(
fwd_scale_inverses,
*saved_tensors,
) = ctx.saved_tensors
inputmats = saved_tensors[: ctx.num_gemms]
inputmats_t = saved_tensors[ctx.num_gemms : 2 * ctx.num_gemms]
weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms]
weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms]
main_grads = saved_tensors[4 * ctx.num_gemms :]
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
for i in ctx.num_gemms:
w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w = torch.nn.Parameter(weights[i], False)
w.main_grad = main_grads[i]
weights[i] = w
# preprocess grad_output
global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
grad_output = grad_output.contiguous()
grad_output_mats = torch.split(
grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits
)
grad_output = [None] * ctx.num_gemms
grad_output_c = [None] * ctx.num_gemms
grad_output_t = [None] * ctx.num_gemms
grad_biases = [None] * ctx.num_gemms
if ctx.fp8:
if ctx.use_bias:
for i in range(ctx.num_gemms):
grad_biases[i], grad_output[i] = tex.bgrad_quantize(
grad_output_mats[i], ctx.grad_output_quantizers[i]
)
else:
grad_output = tex.fused_multi_quantize(
grad_output_mats,
None,
ctx.grad_output_quantizers,
TE_DType[ctx.activation_dtype],
)
else:
grad_output = grad_output_mats
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
......@@ -273,58 +222,52 @@ class _BatchedLinear(torch.autograd.Function):
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.requires_dgrad:
logger.debug("Running backward in %s", ctx.activation_dtype)
dgrad = torch.empty(
(sum(ctx.m_splits), ctx.weights_shape_1),
(sum(ctx.m_splits), int(weights[0].size(1))),
dtype=ctx.activation_dtype,
device=ctx.device,
device=grad_output.device,
)
general_batched_gemm(
batchgemm(
weights,
grad_output,
[dgrad],
grad_output_mats,
torch.split(dgrad, ctx.m_splits),
ctx.activation_dtype,
get_multi_stream_cublas_batchgemm_workspace(),
single_output=True,
layout="NN",
m_splits=ctx.m_splits,
grad=True,
use_split_accumulator=_2X_ACC_DGRAD,
)
if ctx.weights_requires_grad:
if weights[0].requires_grad:
if ctx.fuse_wgrad_accumulation:
wgrad_list = main_grads
wgrad_list = [w.main_grad for w in weights]
else:
wgrad_list = [
torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
for w in weights
]
# WGRAD
_, grad_biases_, _ = general_batched_gemm(
_, grad_biases, _ = batchgemm(
inputmats,
grad_output,
grad_output_mats,
wgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_batchgemm_workspace(),
layout="NT",
grad=True,
m_splits=ctx.m_splits,
use_bias=ctx.use_bias if grad_biases[0] is None else None,
bias=biases,
use_split_accumulator=_2X_ACC_WGRAD,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
)
for i in range(ctx.num_gemms):
if grad_biases[i] is None:
grad_biases[i] = grad_biases_[i]
del grad_biases_
# Deallocate input tensor
clear_tensor_data(*inputmats)
clear_tensor_data(*inputmats_t)
if not ctx.use_bias:
grad_biases = [None] * ctx.num_gemms
def handle_custom_ddp_from_mcore(w, wgrad):
if ctx.weights_requires_grad:
if w.requires_grad:
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False):
......@@ -350,37 +293,34 @@ class _BatchedLinear(torch.autograd.Function):
wgrad_list = [
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
]
else:
wgrad_list = [None] * ctx.num_gemms
if not ctx.use_bias:
grad_biases = [None] * ctx.num_gemms
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None, # is_grad_enabled
None, # m_splits
None, # use_bias
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # fp8_meta
None, # fuse_wgrad_accumulation
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # parallel_mode
None, # is_grad_enabled
*wgrad_list,
*([None] * ctx.num_gemms), # weights_fp8
*grad_biases,
)
class BatchedLinear(TransformerEngineBaseModule):
"""Applies linear transformations to the incoming data list
:math:`y_i = x_iA_i^T + b_i` in a batched way.
......@@ -399,14 +339,31 @@ class BatchedLinear(TransformerEngineBaseModule):
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
get_rng_state_tracker : Callable, default = `None`
used to get the random number generator state tracker for initializing weights.
used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default = `None`
the param passed to get_rng_state_tracker to get the specific rng tracker.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'column', 'row'}, default = `None`
used to decide whether this BatchedLinear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
......@@ -426,7 +383,6 @@ class BatchedLinear(TransformerEngineBaseModule):
would not fit in GPU memory.
"""
def __init__(
self,
num_gemms: int,
......@@ -462,14 +418,14 @@ class BatchedLinear(TransformerEngineBaseModule):
self.apply_bias = bias and not return_bias
self.ub_overlap_rs = ub_overlap_rs
self.ub_overlap_ag = ub_overlap_ag
if ub_overlap_rs or ub_overlap_ag:
assert ub_name is not None, "Userbuffer name [string] is not set."
self.ub_name = ub_name
assert (
not ub_overlap_rs and not ub_overlap_ag
), "BatchedLinear doesn't support Userbuffer overlap."
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self._offsets = {"input": 0, "weight": self.num_gemms, "output": 2 * self.num_gemms, "grad_output": 0}
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
_GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, self.num_gemms, 2 * self.num_gemms
if tp_group is None:
self.tp_size = tp_size
......@@ -492,7 +448,7 @@ class BatchedLinear(TransformerEngineBaseModule):
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
# In batchgemm, we use batch=batch_num to launch blas batchgemm
for i in range(self.num_gemms):
for i in range(int(self.num_gemms)):
# Construct weight parameter
self.register_parameter(
f"weight{i}",
......@@ -506,7 +462,7 @@ class BatchedLinear(TransformerEngineBaseModule):
),
init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=self._offsets["weight"] + i,
fp8_meta_index=_GEMM_WEIGHT + i,
)
# Construct bias parameters if needed
......@@ -515,7 +471,7 @@ class BatchedLinear(TransformerEngineBaseModule):
f"bias{i}",
torch.nn.Parameter(
torch.empty(
self.out_features,
self.out_features * self.batch_num,
device=device,
dtype=params_dtype,
),
......@@ -529,7 +485,11 @@ class BatchedLinear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=self.num_gemms)
self.reset_parameters(defer_init=device == "meta")
if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=self.num_gemms)
self.reset_parameters(defer_init=(device == "meta"))
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
......@@ -543,7 +503,7 @@ class BatchedLinear(TransformerEngineBaseModule):
if not defer_init:
# Set parallelism attributes for linear weights
for i in range(self.num_gemms):
for i in range(int(self.num_gemms)):
set_tensor_model_parallel_attributes(
tensor=getattr(self, f"weight{i}"),
is_parallel=True,
......@@ -553,15 +513,15 @@ class BatchedLinear(TransformerEngineBaseModule):
# Set parallelism attributes for linear biases
if self.use_bias:
for bias in self.bias_names:
for i in range(self.num_gemms):
if self.parallel_mode == "row":
setattr(
getattr(self, bias),
getattr(self, f"bias{i}"),
"sequence_parallel",
self.sequence_parallel,
)
elif self.parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)
set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1)
@no_torch_dynamo()
def forward(
......@@ -593,57 +553,33 @@ class BatchedLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
assert not isinstance(
inp, Float8Tensor
), "BatchedLinear doesn't support input tensor in FP8."
m_splits_batch_gemm = [x * self.batch_num for x in m_splits[0:int(self.num_gemms)]]
assert len(m_splits_batch_gemm) == self.num_gemms, "Number of splits should match number of GEMMs."
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp=inp, num_gemms=self.num_gemms) as inp:
with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
weight_tensors = [getattr(self, f"weight{i}") for i in range(int(self.num_gemms))]
bias_tensors = [getattr(self, f"bias{i}") for i in range(int(self.num_gemms))]
if not self.fp8:
weight_tensors = [
w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
w.from_float8() if isinstance(w, Float8Tensor) else w for w in weight_tensors
]
input_quantizers, weight_quantizers, output_quantizers = (
[None] * self.num_gemms,
[None] * self.num_gemms,
[None] * self.num_gemms,
)
grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
if self.fp8:
input_quantizers = [
self.quantizers["scaling_fwd"][self._offsets["input"] + i]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
input_quantizers[i].internal = True
weight_quantizers = [
self.quantizers["scaling_fwd"][self._offsets["weight"] + i]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
weight_quantizers[i].internal = True
if torch.is_grad_enabled():
grad_output_quantizers = [
self.quantizers["scaling_bwd"][self._offsets["input"] + i]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
grad_output_quantizers[i].internal = True
weight_tensors_fp8 = [None] * int(self.num_gemms)
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled():
linear_fn = _BatchedLinear.apply
linear_fn = _BatchLinear.apply
args = []
else:
linear_fn = _BatchedLinear.forward
linear_fn = _BatchLinear.forward
args = [None]
args += (
inp,
......@@ -652,18 +588,18 @@ class BatchedLinear(TransformerEngineBaseModule):
is_first_microbatch,
self.fp8,
self.fp8_calibration,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_output_quantizers,
self.fp8_meta,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
CPUOffloadEnabled,
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
self,
skip_fp8_weight_update,
*weight_tensors,
*weight_tensors_fp8,
*bias_tensors,
)
out = linear_fn(*args)
......@@ -678,7 +614,6 @@ class BatchedLinear(TransformerEngineBaseModule):
)
]
).view(out_shape)
if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
return out
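For context, a short hypothetical end-to-end sketch of how `BatchedLinear` is exercised, mirroring the non-FP8 path of `test_batched_linear_accuracy` above; the sizes and bf16 dtype are illustrative, and `NVTE_MOE_BATCHCOUNT` is assumed to keep its default of 2.

# Hypothetical usage sketch mirroring the accuracy test (non-FP8 path).
import torch
from transformer_engine.pytorch import BatchedLinear

num_gemms, hidden_size, seq_len, bs = 4, 768, 2048, 2
layer = BatchedLinear(
    num_gemms,
    hidden_size,
    4 * hidden_size,
    bias=False,
    params_dtype=torch.bfloat16,
    device="cuda",
)
inp = torch.randn(seq_len, bs, hidden_size, dtype=torch.bfloat16,
                  device="cuda", requires_grad=True)
# One split per GEMM; as in the test, each split already includes the batch dimension.
m_splits = [(seq_len // num_gemms) * bs] * num_gemms
out = layer(inp, m_splits)  # roughly [seq_len, bs, 4 * hidden_size]
out.sum().backward()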