"vscode:/vscode.git/clone" did not exist on "a27f87da3429eafbbaf4f5372bf458ae40e30618"
Commit 0d402d26 (unverified), authored by Vasiliy Kuznetsov, committed by GitHub
Browse files

online fp8 quant with streaming weight post-processing (#29196)


Signed-off-by: vasiliy <vasiliy@fb.com>
parent d1b5e7af
...@@ -465,6 +465,30 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -465,6 +465,30 @@ class Fp8LinearMethod(LinearMethodBase):
output_size_per_partition, input_size_per_partition, weight_loader output_size_per_partition, input_size_per_partition, weight_loader
) )
else: else:
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
    """Load one weight chunk, then post-process the layer once it is full.

    All arguments are forwarded untouched to the wrapped ``weight_loader``
    and its return value is handed back to the caller. As a side effect,
    a running element count is kept on ``layer``; when that count equals
    ``layer.weight.numel()`` the layer is post-processed immediately.
    """
    # Delegate the actual load of this chunk to the wrapped loader.
    result = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

    # The running total lives on the layer so it persists across calls;
    # a missing attribute means this is the first chunk.
    # NOTE(review): this counts the elements of the incoming tensor, which
    # presumably matches what the wrapped loader writes — confirm for
    # loaders that only copy a slice of `loaded_weight`.
    seen_numel = getattr(layer, "_loaded_numel", 0) + loaded_weight.numel()
    layer._loaded_numel = seen_numel

    if seen_numel != layer.weight.numel():
        # More chunks are still expected for this layer.
        return result

    # Every element has arrived: post-process right away.
    self.process_weights_after_loading(layer)
    # Drop the bookkeeping counter now that it has served its purpose.
    del layer._loaded_numel
    # Turn the usual, later `process_weights_after_loading` call into a
    # no-op for this layer.
    layer._already_called_process_weights_after_loading = True
    return result
# For non-serialized checkpoints, use original dtype # For non-serialized checkpoints, use original dtype
weight = ModelWeightParameter( weight = ModelWeightParameter(
data=torch.empty( data=torch.empty(
...@@ -474,7 +498,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -474,7 +498,7 @@ class Fp8LinearMethod(LinearMethodBase):
), ),
input_dim=1, input_dim=1,
output_dim=0, output_dim=0,
weight_loader=weight_loader, weight_loader=patched_weight_loader,
) )
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
...@@ -515,6 +539,9 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -515,6 +539,9 @@ class Fp8LinearMethod(LinearMethodBase):
layer.register_parameter("input_scale", None) layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
size_k_first = True size_k_first = True
input_scale = None input_scale = None
# TODO(rob): refactor block quant into separate class. # TODO(rob): refactor block quant into separate class.
...@@ -738,6 +765,41 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -738,6 +765,41 @@ class Fp8MoEMethod(FusedMoEMethodBase):
f"weight quantization block_k = {block_k}." f"weight quantization block_k = {block_k}."
) )
# If we are doing online quantization, patch the weight loader to call
# `process_weights_after_loading` in a streaming fashion as soon as the
# last weight chunk is loaded.
if not self.quant_config.is_checkpoint_fp8_serialized:
    weight_loader = extra_weight_attrs["weight_loader"]

    def patched_weight_loader(param, loaded_weight, *args, **kwargs):
        """Load one weight chunk, then post-process the layer once full.

        All arguments are forwarded untouched to the original
        ``weight_loader`` and its return value is handed back. A running
        element count is kept on ``layer``; when it equals the combined
        size of ``w13_weight`` and ``w2_weight`` the layer is
        post-processed immediately.
        """
        # Load the current weight chunk.
        res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

        # Track how many elements have been loaded so far. The counter
        # lives on the layer so that it persists across calls.
        # NOTE(review): this counts the elements of the incoming tensor,
        # which presumably matches what the wrapped loader writes —
        # confirm for loaders that only copy a slice of `loaded_weight`.
        layer._loaded_numel = (
            getattr(layer, "_loaded_numel", 0) + loaded_weight.numel()
        )

        # Once all of the elements are loaded, call
        # `process_weights_after_loading`.
        target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
        if layer._loaded_numel == target_loaded_numel:
            self.process_weights_after_loading(layer)

            # Delete the bookkeeping.
            del layer._loaded_numel
            # Prevent the usual `process_weights_after_loading` call
            # from doing anything.
            layer._already_called_process_weights_after_loading = True

        return res

    # Build a genuinely new dict instead of aliasing the one we were
    # given: plain assignment would share the object, so installing the
    # patched loader would also change the mapping seen by any other
    # object that still holds a reference to the original.
    extra_weight_attrs = {
        **extra_weight_attrs,
        "weight_loader": patched_weight_loader,
    }
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
...@@ -839,6 +901,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -839,6 +901,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.rocm_aiter_moe_enabled = False self.rocm_aiter_moe_enabled = False
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# Lazy import to avoid importing triton too early. # Lazy import to avoid importing triton too early.
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment