"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "978f1d72963f161654188b9ec3658e99d1e22dba"
Unverified commit a7eeb28b authored by Li Tao, committed by GitHub

[PyTorch] Support TP Overlap in Per-Tensor Current Scaling Recipe (#1554)



* support tp-comm-overlap in Current Scaling recipe
Signed-off-by: Li Tao <lit@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* clean
Signed-off-by: Li Tao <lit@nvidia.com>

* fix test recipe argument to generalize to MXFP8
Signed-off-by: Li Tao <lit@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Reduce duplicated transposes in certain cases
Signed-off-by: Li Tao <lit@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Use per_tensor_scaling() to distinguish DS from CS
Signed-off-by: Li Tao <lit@nvidia.com>

* minor fixes
Signed-off-by: Li Tao <lit@nvidia.com>

* change comment description
Signed-off-by: Li Tao <lit@nvidia.com>

* add multi-layer unit test for tp overlap
Signed-off-by: Li Tao <lit@nvidia.com>

* support test cases that run several times
Signed-off-by: Li Tao <lit@nvidia.com>

* avoid saving UB tensor in prepare_for_saving
Signed-off-by: Li Tao <lit@nvidia.com>

* fix
Signed-off-by: Li Tao <lit@nvidia.com>

* switch to a simple fix
Signed-off-by: Li Tao <lit@nvidia.com>

* formatting
Signed-off-by: Li Tao <lit@nvidia.com>

* simplify test cases; avoid additional clone()
Signed-off-by: Li Tao <lit@nvidia.com>

* fall back to get_buffer in LayerNormMLP
Signed-off-by: Li Tao <lit@nvidia.com>

* use 2 layers for FP8 TP-overlap multi-layer test for better tolerance; limit max GPUs for test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Li Tao <lit@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: zhongboz <zhongboz@nvidia.com>
parent 37339478
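For context, the feature added here can be exercised roughly as follows. This is a minimal sketch, not code from the PR: the userbuffers bootstrap is shown only as a comment because its exact signature varies across TE versions, and the layer arguments and shapes are illustrative.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, Float8CurrentScaling

# Per-tensor current scaling recipe, now usable with comm+GEMM overlap.
recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)

# Assumption: torch.distributed and the TE userbuffers must already be
# initialized (e.g. via te.module.base.initialize_ub(...)) before building
# layers with ub_* overlap flags; that bootstrap is omitted here.
layer = te.LayerNormMLP(
    1024,                # hidden_size (illustrative)
    4096,                # ffn_hidden_size (illustrative)
    set_parallel_mode=True,
    sequence_parallel=True,
    ub_overlap_ag=True,  # overlap all-gather with the FC1 GEMM
    ub_overlap_rs=True,  # overlap reduce-scatter with the FC2 GEMM
)

x = torch.randn(1024, 2, 1024, device="cuda")  # [seq, batch, hidden]
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
```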
......@@ -17,13 +17,25 @@ import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class multi_module_model(torch.nn.Module):
def __init__(self, module, num_layers, *args, **kwargs):
super().__init__()
self.num_layers = num_layers
self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def _te_layer_argtype(name):
te_layers = [
te.Linear,
......@@ -40,10 +52,12 @@ def _te_layer_argtype(name):
return layer_map[name.lower()]
def _get_layer_args(config, tp_group, tp_size, reference=False):
def _get_layer_args(config, tp_group, tp_size, num_layers, reference=False):
hidden_size = config.num_heads * config.head_dim
ffn_hidden_size = 4 * hidden_size
qkv_size = 3 * hidden_size
if num_layers > 1 and config.layer_type != te.TransformerLayer:
raise ValueError("Stacked layers are only supported for te.TransformerLayer!")
input_shape = [config.seq_length, config.batch_size, hidden_size]
args = [hidden_size]
kwargs = {
......@@ -106,6 +120,9 @@ def _parse_args(argv=None, namespace=None):
description="Test a Transformer Engine layer with GEMM+comm overlap via Userbuffers."
)
parser.add_argument("-l", "--layer-type", type=_te_layer_argtype, default=te.LayerNormMLP)
parser.add_argument(
"--num-layers", type=int, default=1, help="Number of identical layers to stack."
)
parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
parser.add_argument("-s", "--seq-length", type=int, default=1024, help="Input sequence length.")
parser.add_argument(
......@@ -142,6 +159,13 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
)
parser.add_argument(
"--quantization",
type=str.lower,
default="none",
choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"],
help="Quantization recipe",
)
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
)
......@@ -341,7 +365,9 @@ def _train(opts):
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
# Initialize the Transformer Engine layer with overlap
args, kwargs, input_shape = _get_layer_args(opts, nccl_world, opts.tp)
args, kwargs, input_shape = _get_layer_args(
opts, nccl_world, opts.tp, num_layers=opts.num_layers
)
# Initialize userbuffers
ub_cfgs = None
if opts.overlap_rs_dgrad:
......@@ -359,7 +385,7 @@ def _train(opts):
)
with te.fp8_model_init(enabled=opts.fp8_init):
test_model = opts.layer_type(*args, **kwargs)
test_model = multi_module_model(opts.layer_type, opts.num_layers, *args, **kwargs)
dist_print("Initialized test model...", debug=True)
if WORLD_RANK == 0:
pprint.pprint(kwargs)
......@@ -367,9 +393,11 @@ def _train(opts):
dist.barrier()
# Initialize the reference model and copy all parameters
ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, opts.tp, reference=True)
ref_args, ref_kwargs, _ = _get_layer_args(
opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True
)
with te.fp8_model_init(enabled=opts.fp8_init):
ref_model = opts.layer_type(*ref_args, **ref_kwargs)
ref_model = multi_module_model(opts.layer_type, opts.num_layers, *ref_args, **ref_kwargs)
dist_print("Initialized reference model...", debug=True)
for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()):
with torch.no_grad():
......@@ -379,7 +407,13 @@ def _train(opts):
# Fp8 recipe setup
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
fp8_recipe = None
if opts.quantization == "fp8_delayed_scaling":
fp8_recipe = DelayedScaling(
fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max"
)
elif opts.quantization == "fp8_current_scaling":
fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
# Prepare random input tensors
test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
......
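As a quick illustration of the stacking wrapper added above, the following standalone sketch (single process, no TP overlap, hypothetical sizes) builds a two-layer stack of te.TransformerLayer and runs a forward pass:

```python
import torch
import transformer_engine.pytorch as te

# Mirrors the multi_module_model wrapper from the test script above.
class MultiModuleModel(torch.nn.Module):
    def __init__(self, module, num_layers, *args, **kwargs):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [module(*args, **kwargs) for _ in range(num_layers)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

hidden_size, num_heads = 1024, 16
model = MultiModuleModel(te.TransformerLayer, 2, hidden_size, 4 * hidden_size, num_heads)
x = torch.randn(128, 2, hidden_size, device="cuda")  # [seq, batch, hidden]
y = model(x)
```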
......@@ -30,8 +30,11 @@ TE_LAYERS = [
]
MAX_LAYER_NAME_LENGTH = max([len(layer.__name__) for layer in TE_LAYERS])
# To avoid numerical tolerance issues with comm+GEMM overlap, limit the number of GPUs used
MAX_GPUS_TO_USE = 4
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = torch.cuda.device_count()
NUM_PROCS: int = min(torch.cuda.device_count(), MAX_GPUS_TO_USE)
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
if tex.ubuf_built_with_mpi():
LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python3"]
......@@ -83,7 +86,9 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
raise AssertionError(result.stderr.decode())
def _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8):
def _run_layer_with_overlap(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1
):
test_path = TEST_ROOT / "run_layer_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
......@@ -93,6 +98,7 @@ def _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad,
f"--num-heads={NUM_HEADS}",
f"--head-dim={HEAD_DIM}",
f"--layer-type={layer_type}",
f"--num-layers={num_layers}",
]
if layer_type in [te.Linear.__name__, te.LayerNormLinear.__name__]:
test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}")
......@@ -104,6 +110,7 @@ def _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad,
if not fp8_available:
pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8")
test_cmd.append(f"--quantization={quantization}")
os.environ["PYTORCH_JIT"] = "0"
os.environ["NVTE_TORCH_COMPILE"] = "0"
......@@ -195,7 +202,65 @@ def test_bulk_overlaps(comm_type, fp8, connections):
_run_gemm_with_overlap(comm_type, True, False, False, fp8)
@pytest.mark.parametrize("fp8", (False, True), ids=[" BF16 ", " FP8 "])
@pytest.mark.parametrize(
"fp8",
(False,),
ids=[
" BF16 ",
],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
[
(te.Linear.__name__, "row", False),
(te.Linear.__name__, "column", False),
(te.Linear.__name__, "column", True),
(te.LayerNormLinear.__name__, "row", False),
(te.LayerNormLinear.__name__, "column", False),
(te.LayerNormLinear.__name__, "column", True),
]
+ list(
zip(
[layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
[None] * len(TE_LAYERS[2:]) * 2,
[False, True] * len(TE_LAYERS[2:]),
)
),
ids=[
f" {te.Linear.__name__} - ROW-PARALLEL ",
f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",
]
+ [
" " + " - ".join(test_name_parts) + " "
for test_name_parts in zip(
[layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]),
)
],
)
def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None)
@pytest.mark.parametrize(
"quantization",
["fp8_delayed_scaling", "fp8_current_scaling"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
[
......@@ -229,8 +294,99 @@ def test_bulk_overlaps(comm_type, fp8, connections):
)
],
)
def test_layers_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8):
def test_layers_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization)
@pytest.mark.parametrize(
"fp8",
(False,),
ids=[
" BF16 ",
],
)
@pytest.mark.parametrize(
"num_layers",
(2,),
ids=[
" 2 layers ",
],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
list(
zip(
[te.TransformerLayer.__name__ for _ in range(2)],
[None] * 2,
[False, True],
)
),
ids=[
" " + " - ".join(test_name_parts) + " "
for test_name_parts in zip(
[te.TransformerLayer.__name__ for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"],
)
],
)
def test_multi_layer_with_overlap_bf16(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, num_layers
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, num_layers
)
@pytest.mark.parametrize(
"quantization",
["fp8_delayed_scaling", "fp8_current_scaling"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
)
@pytest.mark.parametrize(
"num_layers",
(2,),
ids=[
" 2 layers ",
],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
list(
zip(
[te.TransformerLayer.__name__ for _ in range(2)],
[None] * 2,
[False, True],
)
),
ids=[
" " + " - ".join(test_name_parts) + " "
for test_name_parts in zip(
[te.TransformerLayer.__name__ for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"],
)
],
)
def test_multi_layer_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8)
_run_layer_with_overlap(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
)
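Concretely, for the new multi-layer current-scaling case the helper above assembles a command along these lines (a sketch: the GPU count follows NUM_PROCS, and the size constants are assumptions, not values shown in this diff):

```python
# Roughly what _run_layer_with_overlap builds for the multi-layer FP8 case
# on a 4-GPU, non-MPI build:
test_cmd = [
    "torchrun", "--nproc_per_node=4",
    "run_layer_with_overlap.py",
    "--num-heads=16", "--head-dim=64",  # assumed test constants
    "--layer-type=TransformerLayer",
    "--num-layers=2",
    "--fp8",
    "--quantization=fp8_current_scaling",
]
```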
......@@ -242,8 +242,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#if CUDA_VERSION >= 12080
cublasLtMatmulMatrixScale_t scaling_mode;
#endif
if ((is_delayed_tensor_scaling(inputA->scaling_mode) &&
is_delayed_tensor_scaling(inputB->scaling_mode))) {
if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
void *A_scale_inverse = param.A_scale_inv;
void *B_scale_inverse = param.B_scale_inv;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
......
......@@ -77,6 +77,10 @@ class Recipe:
"""Whether the given recipe is (per-tensor) current scaling."""
return isinstance(self, Float8CurrentScaling)
def float8_per_tensor_scaling(self):
"""Whether the given recipe is per-tensor scaling."""
return isinstance(self, (DelayedScaling, Float8CurrentScaling))
@dataclass()
class DelayedScaling(Recipe):
......
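A quick sanity check of the new predicate; `MXFP8BlockScaling` is assumed to be importable from the same module as a representative non-per-tensor recipe:

```python
from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    MXFP8BlockScaling,  # assumed import: any block-scaling recipe works here
)

assert DelayedScaling().float8_per_tensor_scaling()         # DS is per-tensor
assert Float8CurrentScaling().float8_per_tensor_scaling()   # CS is per-tensor
assert not MXFP8BlockScaling().float8_per_tensor_scaling()  # block scaling is not
```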
......@@ -223,9 +223,8 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
}
const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
// Unlike delayed scaling, in current scaling the scale is not known, so scale_inv should be an empty buffer
opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
at::Tensor scale_inv = at::empty(scale_inv_torch_shape, opts);
// In current scaling, the scale is not known in advance, so we initialize it to 1 to avoid division by zero; if a scale has already been computed, it will be set correctly.
at::Tensor scale_inv = at::reciprocal(scale);
py::object ret;
if (internal) {
......
......@@ -148,12 +148,12 @@ class _LayerNormLinear(torch.autograd.Function):
with_input_all_gather = parallel_mode == "column" and sequence_parallel
if fp8:
if (
any([ub_overlap_ag_fprop, ub_overlap_rs_fprop])
and not FP8GlobalStateManager.get_fp8_recipe().delayed()
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
):
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling"
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
if input_quantizer is None:
......@@ -177,9 +177,19 @@ class _LayerNormLinear(torch.autograd.Function):
columnwise=backward_needs_input,
)
# Avoid a duplicated transpose in `_fix_gathered_fp8_transpose`
if (
fp8
and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
and ub_bulk_dgrad
):
input_quantizer.set_usage(rowwise=True, columnwise=False)
ub_obj_fprop = None
ln_out = None
if ub_overlap_ag_fprop:
# For DelayedScaling, the output of normalization is already in FP8.
# For Float8CurrentScaling, we keep the normalization output in high precision and quantize it to FP8 afterwards.
if ub_overlap_ag_fprop and not isinstance(input_quantizer, Float8CurrentScalingQuantizer):
ub_obj_fprop = get_ub(ub_name + "_fprop")
ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)
elif with_quantized_norm:
......@@ -208,6 +218,14 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_return = ln_out if return_layernorm_output else None
nvtx_range_pop(f"{nvtx_label}.norm")
# For Float8CurrentScalingQuantizer, layernorm/rmsnorm is not fused with the quantizer,
# so the output of normalization is in high precision and must be quantized to FP8 and placed in the buffer.
if ub_overlap_ag_fprop and isinstance(input_quantizer, Float8CurrentScalingQuantizer):
ub_obj_fprop = get_ub(ub_name + "_fprop")
ln_out_local = ln_out
ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)
input_quantizer.quantize(ln_out_local, out=ln_out)
# Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
......@@ -371,7 +389,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight,
bias,
ln_weight,
ln_out,
ln_out.clone() if ub_overlap_ag_fprop else ln_out, # avoid saving a UB buffer
mu,
rsigma,
)
......@@ -464,9 +482,10 @@ class _LayerNormLinear(torch.autograd.Function):
)
and (ctx.fp8_recipe is not None)
):
if not ctx.fp8_recipe.delayed():
if not ctx.fp8_recipe.float8_per_tensor_scaling():
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling"
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
saved_tensors = ctx.saved_tensors
......@@ -553,7 +572,11 @@ class _LayerNormLinear(torch.autograd.Function):
dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
if ctx.grad_output_quantizer is not None:
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
# Avoid a duplicated transpose, which would otherwise be performed in grad_output.update_usage
if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
else:
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
(
grad_output,
......
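The core of the Float8CurrentScaling fprop path above is quantizing the high-precision norm output directly into the preallocated communication buffer, rather than having the fused norm emit FP8. A torch-only emulation of per-tensor current scaling into a preallocated buffer (all names hypothetical; FP8 stands in via torch.float8_e4m3fn):

```python
import torch

def current_scaling_quantize_(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Emulate per-tensor current scaling: the scale comes from this tensor's
    own amax (no history), and quantized values land directly in `out`."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = x.abs().amax()
    scale = fp8_max / torch.clamp(amax, min=1e-12)  # current scale, no amax history
    out.copy_((x * scale).to(torch.float8_e4m3fn))  # write straight into the buffer
    return out

ln_out_local = torch.randn(1024, 1024)  # high-precision norm output
comm_buffer = torch.empty(1024, 1024, dtype=torch.float8_e4m3fn)  # stands in for the UB chunk
current_scaling_quantize_(ln_out_local, comm_buffer)
```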
......@@ -190,12 +190,12 @@ class _LayerNormMLP(torch.autograd.Function):
inputmat = inp.view((-1, in_features))
if fp8:
assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
if (
any([ub_overlap_ag, ub_overlap_rs])
and not FP8GlobalStateManager.get_fp8_recipe().delayed()
if any([ub_overlap_ag, ub_overlap_rs]) and not (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
):
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling"
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
activation_func = _act_func(
......@@ -209,7 +209,7 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
# for standard fp8: layernorm output = FP8
# for fp8 DelayedScaling: layernorm output = FP8
# only output of the linear is returned
# for return_layernorm_output: layernorm output = High precision, then cast to FP8
# high precision layernorm output and output of the linear are returned
......@@ -237,9 +237,19 @@ class _LayerNormMLP(torch.autograd.Function):
columnwise=backwards_needs_fc1_input,
)
# Avoid a duplicated transpose in `_fix_gathered_fp8_transpose`
if (
fp8
and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
and ub_bulk_dgrad
):
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
ub_obj_lnout = None
ln_out = None
if ub_overlap_ag:
# For DelayedScaling, the output of normalization is already in FP8.
# For Float8CurrentScaling, we keep the normalization output in high precision and quantize it to FP8 afterwards.
if ub_overlap_ag and not isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer):
ub_obj_lnout = get_ub("fc1_fprop")
ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True)
elif not with_quantized_norm:
......@@ -263,6 +273,14 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_return = ln_out if return_layernorm_output else None
# For Float8CurrentScalingQuantizer, layernorm/rmsnorm is not fused with the quantizer,
# so the output of normalization is in high precision and must be quantized to FP8 and placed in the buffer.
if ub_overlap_ag and isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer):
ub_obj_lnout = get_ub("fc1_fprop")
ln_out_local = ln_out
ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True)
fc1_input_quantizer.quantize(ln_out_local, out=ln_out)
# Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication
ln_out_gathered = False
......@@ -589,9 +607,10 @@ class _LayerNormMLP(torch.autograd.Function):
)
and (ctx.fp8_recipe is not None)
):
if not ctx.fp8_recipe.delayed():
if not ctx.fp8_recipe.float8_per_tensor_scaling():
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling"
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
saved_tensors = ctx.saved_tensors
......@@ -658,10 +677,11 @@ class _LayerNormMLP(torch.autograd.Function):
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
if ctx.grad_fc2_output_quantizer is not None:
ctx.grad_fc2_output_quantizer.set_usage(
rowwise=True,
columnwise=True,
)
# Avoid a duplicated transpose, which would otherwise be performed in grad_output.update_usage
if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=False)
else:
ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=True)
ub_obj_fc2_dgrad = None
if ctx.ub_overlap_ag:
......
......@@ -129,12 +129,12 @@ class _Linear(torch.autograd.Function):
own_quantized_input = False
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
if (
any([ub_overlap_ag_fprop, ub_overlap_rs_fprop])
and not FP8GlobalStateManager.get_fp8_recipe().delayed()
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
):
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling"
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
if input_quantizer is None:
......@@ -150,10 +150,17 @@ class _Linear(torch.autograd.Function):
quantizer=input_quantizer,
)
else:
input_quantizer.set_usage(
rowwise=True,
columnwise=backward_needs_input,
)
if (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
and ub_bulk_dgrad
):
# Avoid a duplicated transpose in `_fix_gathered_fp8_transpose`
input_quantizer.set_usage(rowwise=True, columnwise=False)
else:
input_quantizer.set_usage(
rowwise=True,
columnwise=backward_needs_input,
)
if not isinstance(inputmat, QuantizedTensor):
inputmat = input_quantizer(inputmat)
own_quantized_input = True
......@@ -364,9 +371,10 @@ class _Linear(torch.autograd.Function):
)
and (ctx.fp8_recipe is not None)
):
if not ctx.fp8_recipe.delayed():
if not ctx.fp8_recipe.float8_per_tensor_scaling():
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling"
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
saved_tensors = ctx.saved_tensors
......@@ -445,7 +453,11 @@ class _Linear(torch.autograd.Function):
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
if ctx.grad_output_quantizer is not None:
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
# Avoid a duplicated transpose, which would otherwise be performed in grad_output.update_usage
if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
else:
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
(
grad_output,
......
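The rowwise/columnwise usage flags toggled above control whether the quantizer also materializes a transposed (columnwise) FP8 copy for the wgrad GEMM. With AG overlap plus per-tensor scaling, that copy would be produced again later anyway (see `_fix_gathered_fp8_transpose`), so it is skipped up front. A torch-only illustration of the kernel being saved (hypothetical names):

```python
import torch

grad_output = torch.randn(8192, 1024).to(torch.float8_e4m3fn)

# rowwise usage: the tensor as-is, consumed by the dgrad GEMM.
rowwise_data = grad_output

# columnwise usage: an explicit transposed copy for the wgrad GEMM. Producing
# it eagerly here and then again after the all-gather is the duplicated
# transpose that the set_usage(columnwise=False) branch avoids.
columnwise_data = grad_output.t().contiguous()
```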
......@@ -213,7 +213,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
amax_epsilon: float = 0.0,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.scale = torch.empty(1, dtype=torch.float32, device=device)
self.scale = torch.ones(1, dtype=torch.float32, device=device)
self.amax = torch.empty(1, dtype=torch.float32, device=device)
self.dtype = fp8_dtype
self.with_amax_reduction = with_amax_reduction
......
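The switch from torch.empty to torch.ones pairs with the C++ create_tensor change above, where scale_inv is now derived as at::reciprocal(scale): an uninitialized scale could hold zeros or garbage and yield an inf/NaN scale_inv. A minimal illustration:

```python
import torch

scale = torch.empty(1)     # uninitialized: may contain 0.0 or garbage
# torch.reciprocal(scale)  # -> potentially inf/NaN scale_inv

scale = torch.ones(1)      # safe default until the first real scale is computed
scale_inv = torch.reciprocal(scale)
assert scale_inv.item() == 1.0  # identity scaling before any amax is observed
```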