Unverified commit 5fdd7bb9, authored by Jianbin Chang, committed by GitHub

[PyTorch] check and try to generate fp8 weight transpose cache before dgrad backward (#1648)



* Add an fp8 weight transpose cache check in backward, and regenerate the cache if it does not exist (a minimal sketch of this pattern follows the commit metadata below)
Signed-off-by: jianbinc <shjwudp@gmail.com>

* Properly handle FSDP shard model weight input.
Signed-off-by: jianbinc <shjwudp@gmail.com>

* Replace Float8Tensor with QuantizedTensor in the cast_master_weights_to_fp8 UT
Signed-off-by: jianbinc <shjwudp@gmail.com>

* Handle the Float8TensorBase issue
Signed-off-by: jianbinc <shjwudp@gmail.com>

* Fix a bug in activation recompute
Signed-off-by: jianbinc <shjwudp@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: jianbinc <shjwudp@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 48f3ca90
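The same pattern is applied in the backward pass of `_GroupedLinear`, `_LayerNormLinear`, `_LayerNormMLP`, and `_Linear` below: the forward pass saves the weight quantizer on `ctx`, and before the dgrad GEMM the backward pass checks whether the FP8 weight still carries the usages it needs (in particular the columnwise/transpose copy), regenerating them if missing. The sketch below is a simplified, standalone illustration of that check; the helper name `ensure_weight_usages_for_dgrad` is hypothetical and the import path is an assumption, not part of this commit.

```python
# Hedged sketch of the transpose-cache check added before the dgrad GEMM.
# Assumption: QuantizedTensor is importable from transformer_engine.pytorch.tensor
# and exposes update_usage(); both names appear in the diff below.
from transformer_engine.pytorch.tensor import QuantizedTensor


def ensure_weight_usages_for_dgrad(weight, quantizer):
    """Regenerate missing FP8 weight copies (e.g. the transpose / columnwise
    cache) so the dgrad GEMM can run, mirroring the checks added below."""
    if quantizer is not None and isinstance(weight, QuantizedTensor):
        # update_usage() rebuilds whichever layout the quantizer requests
        # but the tensor no longer holds.
        weight.update_usage(
            rowwise_usage=quantizer.rowwise_usage,
            columnwise_usage=quantizer.columnwise_usage,
        )
```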
@@ -243,10 +243,10 @@ class MiniFSDP:
         # Flatten the weights and pad to align with world size
         raw_data_list = [
-            _get_raw_data(w).view(-1) if isinstance(w, Float8Tensor) else w.view(-1)
+            _get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1)
             for w in weights
         ]
-        if isinstance(weights[0], Float8Tensor):
+        if isinstance(weights[0], QuantizedTensor):
             raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
         else:
             raw_data_list = [w.view(-1) for w in weights]
@@ -282,7 +282,7 @@ class MiniFSDP:
                 self.weight_indices.append((None, None))
                 self.shard_indices.append((None, None))
-            if isinstance(weights[idx], Float8Tensor):
+            if isinstance(weights[idx], QuantizedTensor):
                 replace_raw_data(
                     weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
                 )
@@ -378,19 +378,13 @@ class MiniFSDP:
             master_weight -= grad * self.lr
         # Step 3: Cast master weights to FP8 or BF16 precision
-        if isinstance(self.weights[0], Float8Tensor):
+        if isinstance(self.weights[0], QuantizedTensor):
             local_weights = []
-            for model_weight, local_weight in zip(self.weights, self.local_weights):
+            for local_weight in self.local_weights:
                 if local_weight is None:
                     local_weights.append(None)
                     continue
-                quantizer = model_weight._get_quantizer()
-                if isinstance(quantizer, Float8CurrentScalingQuantizer):
-                    local_weight = quantizer.create_tensor_from_data(
-                        local_weight.view(-1),
-                        model_weight.dtype,
-                    )
                 local_weights.append(local_weight)
             cast_master_weights_to_fp8(
...
@@ -173,7 +173,7 @@ class _GroupedLinear(torch.autograd.Function):
                     weight_quantizers[i].calibrate(weights[i])
         if is_grad_enabled:
+            ctx.weight_quantizers = weight_quantizers
             ctx.weights_shape_1 = weights[0].shape[1]
             # TODO: update after #1638 is merged.  # pylint: disable=fixme
@@ -294,6 +294,12 @@ class _GroupedLinear(torch.autograd.Function):
                     device=ctx.device,
                 )
+            for weight, quantizer in zip(weights, ctx.weight_quantizers):
+                if quantizer is not None and isinstance(weight, QuantizedTensor):
+                    weight.update_usage(
+                        rowwise_usage=quantizer.rowwise_usage,
+                        columnwise_usage=quantizer.columnwise_usage,
+                    )
             general_grouped_gemm(
                 weights,
                 grad_output,
...
@@ -323,6 +323,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 clear_tensor_data(ln_out, ln_out_total)
         if is_grad_enabled:
+            ctx.weight_quantizer = weight_quantizer
             ctx.ln_out_needs_gather = (
                 weight.requires_grad and parallel_mode == "column" and sequence_parallel
             )
@@ -651,6 +652,11 @@ class _LayerNormLinear(torch.autograd.Function):
             if hasattr(recipe, "fp8_gemm_dgrad"):
                 dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
+        if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor):
+            weight.update_usage(
+                rowwise_usage=ctx.weight_quantizer.rowwise_usage,
+                columnwise_usage=ctx.weight_quantizer.columnwise_usage,
+            )
         dgrad, *_ = general_gemm(
             weight,
             grad_output,
...
@@ -478,6 +478,8 @@ class _LayerNormMLP(torch.autograd.Function):
                 fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
             )
+            ctx.fc1_weight_quantizer = fc1_weight_quantizer
+            ctx.fc2_weight_quantizer = fc2_weight_quantizer
             if not fc1_weight.requires_grad:
                 if not return_layernorm_output:
                     clear_tensor_data(ln_out)
@@ -749,6 +751,11 @@ class _LayerNormMLP(torch.autograd.Function):
             )
             # FC2 DGRAD; Unconditional
+            if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor):
+                ctx.fc2_weight.update_usage(
+                    rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage,
+                    columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage,
+                )
             gemm_output, *_ = general_gemm(
                 fc2_weight,
                 grad_output,
@@ -895,6 +902,13 @@ class _LayerNormMLP(torch.autograd.Function):
                 fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None)
             # FC1 DGRAD: Unconditional
+            if ctx.fc1_weight_quantizer is not None and isinstance(
+                ctx.fc1_weight_quantizer, QuantizedTensor
+            ):
+                ctx.fc1_weight.update_usage(
+                    rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage,
+                    columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage,
+                )
             fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm(
                 fc1_weight,
                 dact,
...
@@ -277,6 +277,7 @@ class _Linear(torch.autograd.Function):
             nvtx_range_pop(f"{nvtx_label}.gemm")
         if is_grad_enabled:
+            ctx.weight_quantizer = weight_quantizer
             saved_inputmat = None
             ctx.backward_input_needs_gather = (
@@ -574,6 +575,12 @@ class _Linear(torch.autograd.Function):
                     recipe.fp8_gemm_dgrad.use_split_accumulator
                 )
+            if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensor):
+                weight_fp8.update_usage(
+                    rowwise_usage=ctx.weight_quantizer.rowwise_usage,
+                    columnwise_usage=ctx.weight_quantizer.columnwise_usage,
+                )
             dgrad, *_, rs_out = general_gemm(
                 weight_fp8,
                 grad_output,
...
@@ -305,4 +305,11 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
             amax=torch.Tensor(),
             fp8_dtype=model_weight._fp8_dtype,
         )
+        if use_fsdp_shard_model_weights and not isinstance(model_weight_fragment, Float8Tensor):
+            # NOTE: The FSDP shard model weight may be a uint8 tensor instead of
+            # a float8 tensor. We should handle this situation properly.
+            model_weight_fragment = quantizer.create_tensor_from_data(
+                model_weight_fragment.view(-1),
+                model_weight.dtype,
+            )
         quantizer.update_quantized(master_weight, model_weight_fragment)
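For the FSDP-sharded case handled in the last hunk, the shard backing a model weight may arrive as a flat uint8 buffer rather than a Float8Tensor, so it is wrapped with the weight's quantizer before `update_quantized` writes into it. Below is a minimal sketch of that wrapping under the same assumption; the helper name `wrap_fsdp_shard_as_float8` and the import path are illustrative, not part of the commit.

```python
# Hedged sketch: wrap a raw uint8 FSDP shard so it can receive quantized master
# weights in place. Assumes Float8Tensor and a quantizer exposing
# create_tensor_from_data(), as used in the hunk above.
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor  # assumed import path


def wrap_fsdp_shard_as_float8(model_weight_fragment, quantizer, model_weight):
    """Return a Float8Tensor view over the shard's raw bytes if it is not one already."""
    if not isinstance(model_weight_fragment, Float8Tensor):
        # The FSDP shard is a flat uint8 view of the FP8 payload; reinterpret it
        # through the quantizer so quantizer.update_quantized() can target it.
        model_weight_fragment = quantizer.create_tensor_from_data(
            model_weight_fragment.view(-1),
            model_weight.dtype,
        )
    return model_weight_fragment
```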