"nndet/planning/vscode:/vscode.git/clone" did not exist on "2a80e77d5adc82954a61243e7fba9aafea436b33"
Unverified commit c1003181 authored by Tim Moon, committed by GitHub

[PyTorch] Set usages for linear op quantizers before forward (#2222)



* Make sure to set usages for linear op quantizers before forward
Signed-off-by: Tim Moon <tmoon@nvidia.com>
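
A rough sketch of the idea, using a stand-in `ToyQuantizer` class and a hypothetical `configure_linear_quantizers` helper rather than Transformer Engine's real quantizer objects: the usages are decided from `requires_grad` before the forward pass runs, so column-wise (transpose) data is only produced when a weight gradient will actually be needed.

```python
# Minimal illustration only: `ToyQuantizer` and `configure_linear_quantizers`
# are stand-ins, not Transformer Engine APIs.
from dataclasses import dataclass


@dataclass
class ToyQuantizer:
    rowwise_usage: bool = True
    columnwise_usage: bool = True

    def set_usage(self, *, rowwise: bool, columnwise: bool) -> None:
        # Mirrors the set_usage calls in the diff: choose which layouts
        # (row-wise data, column-wise/transpose data) will be produced.
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise


def configure_linear_quantizers(
    input_q: ToyQuantizer,
    weight_q: ToyQuantizer,
    grad_output_q: ToyQuantizer,
    *,
    weight_requires_grad: bool,
) -> None:
    # Quantized input is cached for backward (needs column-wise data for wgrad);
    # quantized weight is discarded after forward (row-wise only).
    input_q.set_usage(rowwise=True, columnwise=weight_requires_grad)
    weight_q.set_usage(rowwise=True, columnwise=False)
    grad_output_q.set_usage(rowwise=True, columnwise=weight_requires_grad)


if __name__ == "__main__":
    in_q, w_q, g_q = ToyQuantizer(), ToyQuantizer(), ToyQuantizer()
    configure_linear_quantizers(in_q, w_q, g_q, weight_requires_grad=False)
    assert not in_q.columnwise_usage and not g_q.columnwise_usage
```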

* Avoid unsupported case for fused dbias+quantize kernel

Hopper does not support dbias + FP8 cast without FP8 transpose.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
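
A Python rendering of the new gating logic in the C++ `bgrad_quantize` change below, as a sketch only; the function name `use_fused_dbias_quantize` and its flat argument list are hypothetical. On SM 10.0+ the fused dbias + FP8 cast kernel is always usable; on Hopper and older it is only valid when both row-wise and column-wise (transpose) outputs are requested, and MXFP8 always takes the fused path.

```python
def use_fused_dbias_quantize(
    *,
    is_fp8_quantizer: bool,
    is_mxfp8_quantizer: bool,
    sm_major: int,
    sm_minor: int,
    rowwise_usage: bool,
    columnwise_usage: bool,
) -> bool:
    if is_fp8_quantizer:
        sm_arch = 10 * sm_major + sm_minor
        if sm_arch >= 100:
            # SM 10.0+: fused dbias + FP8 cast kernel is supported.
            return True
        # Older arches: the fused kernel also emits the FP8 transpose, so it
        # only applies when both row-wise and column-wise data are wanted.
        return rowwise_usage and columnwise_usage
    if is_mxfp8_quantizer:
        # Fused dbias + MXFP8 quantize kernel.
        return True
    # Other quantizers: fall back to unfused reduction + quantize.
    return False


# Example: Hopper (SM 9.0) with only row-wise FP8 output -> unfused path.
assert not use_fused_dbias_quantize(
    is_fp8_quantizer=True,
    is_mxfp8_quantizer=False,
    sm_major=9,
    sm_minor=0,
    rowwise_usage=True,
    columnwise_usage=False,
)
```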

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent aee5a821
@@ -635,6 +635,204 @@ def _test_linear(
torch.testing.assert_close(db_test, db_ref, **tols)
def _test_mlp(
*,
bias: bool = True,
hidden_size: int = 32,
local_batch_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
quantization: Optional[str] = None,
quantized_weight: bool = False,
sequence_parallel: bool = False,
) -> None:
"""2-layer MLP
MLP includes GELU activation in order to test op fusions. Model
performs warmup steps in order to test inter-step logic.
"""
# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and quantized_weight:
return
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
mlp_size = hidden_size * world_size
batch_size = local_batch_size
if sequence_parallel:
batch_size *= world_size
in_shape = (batch_size, hidden_size)
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
w1_ref, w1_test = make_reference_and_test_tensors(
(mlp_size, hidden_size),
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
b1_ref, b1_test = None, None
w2_ref, w2_test = make_reference_and_test_tensors(
(hidden_size, mlp_size),
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
b2_ref, b2_test = None, None
if bias:
b1_ref, b1_test = make_reference_and_test_tensors(
(mlp_size,),
test_dtype=dtype,
test_device=device,
)
b2_ref, b2_test = make_reference_and_test_tensors(
(world_size, hidden_size),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
y_ref = torch.nn.functional.linear(y_ref, w1_ref)
if bias:
y_ref += b1_ref
y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
y_ref = torch.nn.functional.linear(y_ref, w2_ref)
if bias:
y_ref += b2_ref.sum(dim=0)
y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
local_mlp_size = mlp_size // world_size
local_mlp_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size)
dx_ref = x_ref.grad
dw1_ref = w1_ref.grad[local_mlp_slice, :]
w1_ref = w1_ref[local_mlp_slice, :]
w1_test = w1_test[local_mlp_slice, :]
dw2_ref = w2_ref.grad[:, local_mlp_slice]
w2_ref = w2_ref[:, local_mlp_slice]
w2_test = w2_test[:, local_mlp_slice]
if bias:
db1_ref = b1_ref.grad[local_mlp_slice]
b1_ref = b1_ref[local_mlp_slice]
b1_test = b1_test[local_mlp_slice]
db2_ref = b2_ref.grad[rank, :]
b2_ref = b2_ref[rank, :]
b2_test = b2_test[rank, :]
else:
db1_ref = None
db2_ref = None
if sequence_parallel:
local_batch_slice = slice(
rank * local_batch_size,
(rank + 1) * local_batch_size,
)
x_ref = x_ref[local_batch_slice, ...]
dx_ref = dx_ref[local_batch_slice, ...]
x_test = x_test[local_batch_slice, ...].clone()
y_ref = y_ref[local_batch_slice, ...]
dy_ref = dy_ref[local_batch_slice, ...]
dy_test = dy_test[local_batch_slice, ...].clone()
x_test.requires_grad_()
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.GELU(),
te_ops.Linear(
hidden_size,
mlp_size,
bias=bias,
device=device,
dtype=dtype,
tensor_parallel_mode="column",
tensor_parallel_group=process_group,
sequence_parallel=sequence_parallel,
),
te_ops.GELU(),
te_ops.Linear(
mlp_size,
hidden_size,
bias=bias,
device=device,
dtype=dtype,
tensor_parallel_mode="row",
tensor_parallel_group=process_group,
sequence_parallel=sequence_parallel,
),
te_ops.GELU(),
)
with torch.no_grad():
model[1].weight.copy_(w1_test)
model[3].weight.copy_(w2_test)
if bias:
model[1].bias.copy_(b1_test)
model[3].bias.copy_(b2_test)
del w1_test, w2_test, b1_test, b2_test
# Warmup steps
for _ in range(3):
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
x_test.grad = None
model[1].weight.grad = None
model[3].weight.grad = None
if bias:
model[1].bias.grad = None
model[3].bias.grad = None
# Forward and backward step
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
dw2_test = model[3].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)
torch.testing.assert_close(dw1_test, dw1_ref, **tols)
torch.testing.assert_close(dw2_test, dw2_ref, **tols)
if bias:
db1_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
db2_test = model[3].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db1_test, db1_ref, **tols)
torch.testing.assert_close(db2_test, db2_ref, **tols)
def _test_fp8_scale_update(
*,
amax_history_len: int = 31,
@@ -801,16 +999,31 @@ def run_parallel_tests() -> None:
for config in itertools.product(
quantization_list,
("column", "row"),
(False, True),
):
if rank == 0:
print(f"Running _test_linear with {config=}")
quantization, tensor_parallel_mode, sequence_parallel = config
dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
_test_linear(
bias=True,  # bias=False is tested in _test_basic_linear
dtype=dtype,
quantization=quantization,
tensor_parallel_mode=tensor_parallel_mode,
sequence_parallel=sequence_parallel,
)
# MLP
for config in itertools.product(quantization_list, (False, True)):
if rank == 0:
print(f"Running _test_mlp with {config=}")
quantization, sequence_parallel = config
dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
_test_mlp(
bias=True, # bias=False is tested in _test_basic_linear
dtype=dtype,
quantization=quantization,
sequence_parallel=sequence_parallel,
)
# FP8 scale update
...
@@ -54,10 +54,25 @@ std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
}
// Check if fused kernel is supported
bool with_fused_kernel = false;
if (detail::IsFloat8Quantizers(quantizer.ptr())) {
auto prop = at::cuda::getCurrentDeviceProperties();
const size_t sm_arch = 10 * prop->major + prop->minor;
if (sm_arch >= 100) {
// Fused kernel for dbias + FP8 cast on SM arch 10.0+
with_fused_kernel = true;
} else if (quantizer_cpp->rowwise_usage && quantizer_cpp->columnwise_usage) {
// Fused kernel for dbias + FP8 cast + FP8 transpose
with_fused_kernel = true;
}
} else if (detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Fused kernel for dbias + MXFP8 quantize
with_fused_kernel = true;
}
// Apply unfused impl if fused kernel is not supported
if (!with_fused_kernel) {
at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0});
quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte);
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
...
@@ -322,6 +322,20 @@ class BasicLinear(BasicOperation):
if self.weight.device.type == "meta":
self.reset_parameters()
def pre_fuser_forward(self, *, requires_grad: bool) -> None:
super().pre_fuser_forward(requires_grad=requires_grad)
if FP8GlobalStateManager.is_fp8_enabled():
# Configure quantizer usages
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
weight_requires_grad = requires_grad and self.weight.requires_grad
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
grad_output_quantizer = self.get_quantizer("backward", 0)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_state(recipe=recipe)
@@ -352,6 +366,35 @@ class BasicLinear(BasicOperation):
and not getattr(self, "_with_quantized_weight", False)
)
# Recipe-specific configuration
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
if recipe is not None:
if recipe.float8_current_scaling():
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon
if getattr(self, "sequence_parallel", False):
tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
if tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
elif tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
if recipe.nvfp4():
if getattr(self, "sequence_parallel", False):
tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
if tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
elif tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
@staticmethod
def _functional_forward(
input: torch.Tensor,  # pylint: disable=redefined-builtin
@@ -731,7 +774,7 @@ class BasicLinear(BasicOperation):
if with_quantized_compute:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=False, columnwise=True)
if with_x_all_gather:
x, x_async = gather_along_first_dim(
x_local,
@@ -912,42 +955,13 @@ class BasicLinear(BasicOperation):
input_requires_grad = ctx.requires_grad
weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
# Quantizers
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = self.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
# Configure quantizers
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
# Recipe-specific configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
if recipe.nvfp4():
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Get autocast dtype if needed
if torch.is_autocast_enabled():
...
@@ -85,7 +85,7 @@ class ForwardLinearBiasActivation(FusedOperation):
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# Quantizers
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
...
@@ -79,7 +79,7 @@ class ForwardLinearBiasAdd(FusedOperation):
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# Quantizers
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = None
...
@@ -58,7 +58,7 @@ class ForwardLinearScaleAdd(FusedOperation):
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# Quantizers
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = None
...
@@ -472,6 +472,10 @@ class OperationFuser:
# Attempt to fuse operations if necessary
self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs)
# Initialization before forward
for idx, op in enumerate(self._basic_ops):
op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward)
# Fuser forward pass
if is_grad_enabled:
forward_func = _OperationFuserAutogradFunction.apply
...
@@ -65,6 +65,13 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
def pre_first_fuser_forward(self) -> None:
"""Preprocessing before first fuser forward pass"""
def pre_fuser_forward(
self,
*,
requires_grad: bool, # pylint: disable=unused-argument
) -> None:
"""Preprocessing before fuser forward pass"""
def get_input_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input tensor"""
@@ -710,6 +717,10 @@ class FusedOperation(FusibleOperation):
for op in self.basic_ops:
op.pre_first_fuser_forward()
def pre_fuser_forward(self, *, requires_grad: bool) -> None:
for op in self.basic_ops:
op.pre_fuser_forward(requires_grad=requires_grad)
def forward(
self,
input: torch.Tensor,  # pylint: disable=redefined-builtin
...