Unverified Commit 5fdd7bb9 authored by Jianbin Chang's avatar Jianbin Chang Committed by GitHub
Browse files

[PyTorch] check and try to generate fp8 weight transpose cache before dgrad backward (#1648)



* Add fp8 weight transpose cache check in backward, and regenerate it if it does not exist
Signed-off-by: default avatarjianbinc <shjwudp@gmail.com>

* Properly handle fsdp shard model weight input.
Signed-off-by: default avatarjianbinc <shjwudp@gmail.com>

* move Float8Tensor to QuantizedTensor in cast_master_weights_to_fp8 UT
Signed-off-by: default avatarjianbinc <shjwudp@gmail.com>

* handle Float8TensorBase issue
Signed-off-by: default avatarjianbinc <shjwudp@gmail.com>

* fix bug in activation recompute
Signed-off-by: default avatarjianbinc <shjwudp@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarjianbinc <shjwudp@gmail.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 48f3ca90
......@@ -243,10 +243,10 @@ class MiniFSDP:
# Flatten the weights and pad to align with world size
raw_data_list = [
_get_raw_data(w).view(-1) if isinstance(w, Float8Tensor) else w.view(-1)
_get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1)
for w in weights
]
if isinstance(weights[0], Float8Tensor):
if isinstance(weights[0], QuantizedTensor):
raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
else:
raw_data_list = [w.view(-1) for w in weights]
......@@ -282,7 +282,7 @@ class MiniFSDP:
self.weight_indices.append((None, None))
self.shard_indices.append((None, None))
if isinstance(weights[idx], Float8Tensor):
if isinstance(weights[idx], QuantizedTensor):
replace_raw_data(
weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
)
......@@ -378,19 +378,13 @@ class MiniFSDP:
master_weight -= grad * self.lr
# Step 3: Cast master weights to FP8 or BF16 precision
if isinstance(self.weights[0], Float8Tensor):
if isinstance(self.weights[0], QuantizedTensor):
local_weights = []
for model_weight, local_weight in zip(self.weights, self.local_weights):
for local_weight in self.local_weights:
if local_weight is None:
local_weights.append(None)
continue
quantizer = model_weight._get_quantizer()
if isinstance(quantizer, Float8CurrentScalingQuantizer):
local_weight = quantizer.create_tensor_from_data(
local_weight.view(-1),
model_weight.dtype,
)
local_weights.append(local_weight)
cast_master_weights_to_fp8(
......
......@@ -173,7 +173,7 @@ class _GroupedLinear(torch.autograd.Function):
weight_quantizers[i].calibrate(weights[i])
if is_grad_enabled:
ctx.weight_quantizers = weight_quantizers
ctx.weights_shape_1 = weights[0].shape[1]
# TODO: update after #1638 is merged. # pylint: disable=fixme
......@@ -294,6 +294,12 @@ class _GroupedLinear(torch.autograd.Function):
device=ctx.device,
)
for weight, quantizer in zip(weights, ctx.weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensor):
weight.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
)
general_grouped_gemm(
weights,
grad_output,
......
......@@ -323,6 +323,7 @@ class _LayerNormLinear(torch.autograd.Function):
clear_tensor_data(ln_out, ln_out_total)
if is_grad_enabled:
ctx.weight_quantizer = weight_quantizer
ctx.ln_out_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel
)
......@@ -651,6 +652,11 @@ class _LayerNormLinear(torch.autograd.Function):
if hasattr(recipe, "fp8_gemm_dgrad"):
dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor):
weight.update_usage(
rowwise_usage=ctx.weight_quantizer.rowwise_usage,
columnwise_usage=ctx.weight_quantizer.columnwise_usage,
)
dgrad, *_ = general_gemm(
weight,
grad_output,
......
......@@ -478,6 +478,8 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
)
ctx.fc1_weight_quantizer = fc1_weight_quantizer
ctx.fc2_weight_quantizer = fc2_weight_quantizer
if not fc1_weight.requires_grad:
if not return_layernorm_output:
clear_tensor_data(ln_out)
......@@ -749,6 +751,11 @@ class _LayerNormMLP(torch.autograd.Function):
)
# FC2 DGRAD; Unconditional
if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor):
ctx.fc2_weight.update_usage(
rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage,
columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage,
)
gemm_output, *_ = general_gemm(
fc2_weight,
grad_output,
......@@ -895,6 +902,13 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None)
# FC1 DGRAD: Unconditional
        if ctx.fc1_weight_quantizer is not None and isinstance(
            ctx.fc1_weight, QuantizedTensor
        ):
ctx.fc1_weight.update_usage(
rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage,
columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage,
)
fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm(
fc1_weight,
dact,
......
......@@ -277,6 +277,7 @@ class _Linear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.gemm")
if is_grad_enabled:
ctx.weight_quantizer = weight_quantizer
saved_inputmat = None
ctx.backward_input_needs_gather = (
......@@ -574,6 +575,12 @@ class _Linear(torch.autograd.Function):
recipe.fp8_gemm_dgrad.use_split_accumulator
)
if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensor):
weight_fp8.update_usage(
rowwise_usage=ctx.weight_quantizer.rowwise_usage,
columnwise_usage=ctx.weight_quantizer.columnwise_usage,
)
dgrad, *_, rs_out = general_gemm(
weight_fp8,
grad_output,
......
......@@ -305,4 +305,11 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
amax=torch.Tensor(),
fp8_dtype=model_weight._fp8_dtype,
)
if use_fsdp_shard_model_weights and not isinstance(model_weight_fragment, Float8Tensor):
            # NOTE: The fsdp shard model weight may be a uint8 tensor instead of
            # a float8 tensor. We should handle this situation properly.
model_weight_fragment = quantizer.create_tensor_from_data(
model_weight_fragment.view(-1),
model_weight.dtype,
)
quantizer.update_quantized(master_weight, model_weight_fragment)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment