Unverified commit d2389c12, authored by Vasiliy Kuznetsov, committed by GitHub
Browse files

fp8 online quant: split out Fp8OnlineLinearMethod (#32189)

parent 22375f8d
...@@ -133,7 +133,7 @@ def test_kv_cache_model_load_and_run( ...@@ -133,7 +133,7 @@ def test_kv_cache_model_load_and_run(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
) )
def test_load_fp16_model( def test_online_quantization(
vllm_runner, vllm_runner,
kv_cache_dtype: str, kv_cache_dtype: str,
force_marlin: bool, force_marlin: bool,
...@@ -191,6 +191,9 @@ def test_load_fp16_model( ...@@ -191,6 +191,9 @@ def test_load_fp16_model(
llm.apply_model(check_model) llm.apply_model(check_model)
outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
print(outputs[0][1])
@pytest.mark.skipif( @pytest.mark.skipif(
not is_quant_method_supported("fp8"), not is_quant_method_supported("fp8"),
......
...@@ -230,9 +230,14 @@ class Fp8Config(QuantizationConfig): ...@@ -230,9 +230,14 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping, fused_mapping=self.packed_modules_mapping,
): ):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
quant_method = Fp8LinearMethod(self) if not self.is_checkpoint_fp8_serialized:
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) online_method = Fp8OnlineLinearMethod(self)
return quant_method online_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return online_method
else:
offline_method = Fp8LinearMethod(self)
offline_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return offline_method
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
if is_layer_skipped( if is_layer_skipped(
prefix=prefix, prefix=prefix,
...@@ -295,13 +300,8 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -295,13 +300,8 @@ class Fp8LinearMethod(LinearMethodBase):
Supports loading FP8 checkpoints with static weight scale and Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale. dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Limitations: Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support. 1. Only support float8_e4m3fn data type due to the limitation of
2. Only support float8_e4m3fn data type due to the limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
Args: Args:
...@@ -388,54 +388,11 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -388,54 +388,11 @@ class Fp8LinearMethod(LinearMethodBase):
self.weight_block_size, self.weight_block_size,
) )
# WEIGHT
if self.quant_config.is_checkpoint_fp8_serialized:
weight = create_fp8_weight_parameter( weight = create_fp8_weight_parameter(
output_size_per_partition, input_size_per_partition, weight_loader output_size_per_partition, input_size_per_partition, weight_loader
) )
else:
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)
# Delete the bookkeeping
del layer._loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer._already_called_process_weights_after_loading = True
return res
# For non-serialized checkpoints, use original dtype
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
)
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE # WEIGHT SCALE
if not self.block_quant: if not self.block_quant:
scale = create_fp8_scale_parameter( scale = create_fp8_scale_parameter(
...@@ -468,9 +425,6 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -468,9 +425,6 @@ class Fp8LinearMethod(LinearMethodBase):
layer.register_parameter("input_scale", scale) layer.register_parameter("input_scale", scale)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
size_k_first = True size_k_first = True
input_scale = None input_scale = None
# TODO(rob): refactor block quant into separate class. # TODO(rob): refactor block quant into separate class.
...@@ -488,27 +442,20 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -488,27 +442,20 @@ class Fp8LinearMethod(LinearMethodBase):
# If checkpoint not serialized fp8, quantize the weights. # If checkpoint not serialized fp8, quantize the weights.
else: else:
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
weight = qweight.t()
# If checkpoint is fp8 per-tensor, handle that there are N scales for N # If checkpoint is fp8 per-tensor, handle that there are N scales for N
# shards in a fused module # shards in a fused module
else:
weight = layer.weight weight = layer.weight
weight_scale = layer.weight_scale weight_scale = layer.weight_scale
# If using w8a8, torch._scaled_mm needs per tensor, so # If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight. # requantize the logical shards as a single weight.
if not self.use_marlin: if not self.use_marlin:
weight, weight_scale, input_scale = ( weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
process_fp8_weight_tensor_strategy(
weight, weight,
weight_scale, weight_scale,
layer.logical_widths, layer.logical_widths,
getattr(layer, "input_scale", None), getattr(layer, "input_scale", None),
) )
)
if self.act_q_static: if self.act_q_static:
assert input_scale is not None assert input_scale is not None
input_scale = input_scale.max() input_scale = input_scale.max()
...@@ -607,6 +554,89 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -607,6 +554,89 @@ class Fp8LinearMethod(LinearMethodBase):
return self.fp8_linear.apply_weights(layer, x, bias) return self.fp8_linear.apply_weights(layer, x, bias)
class Fp8OnlineLinearMethod(Fp8LinearMethod):
    """Online variant of Fp8LinearMethod.

    Loads an fp16/bf16 checkpoint and quantizes the weights while they are
    being loaded.
    """

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantized ``weight`` parameter on ``layer``.

        The parameter is allocated in ``params_dtype`` and given a wrapped
        weight loader that counts copied elements and calls
        ``process_weights_after_loading`` as soon as the whole tensor has
        been filled.

        Args:
            layer: Module that receives the parameter and bookkeeping
                attributes.
            input_size_per_partition: Number of columns of the weight.
            output_partition_sizes: Per-shard row counts; their sum is the
                number of rows of the weight.
            input_size: Accepted for interface compatibility; unused here.
            output_size: Accepted for interface compatibility; unused here.
            params_dtype: Dtype of the unquantized weight buffer.
            **extra_weight_attrs: Only ``"weight_loader"`` is read.
        """
        rows = sum(output_partition_sizes)
        inner_loader = extra_weight_attrs.get("weight_loader")

        # Metadata recorded on the layer itself.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = rows
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHT
        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # Start the running element count on the first chunk.
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # Delegate the actual load while counting copied elements.
            numel_counter = CopyNumelCounter()
            with numel_counter:
                result = inner_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += numel_counter.copied_numel

            # Not every element has arrived yet: nothing more to do.
            expected_numel = layer.weight.numel()
            if layer._loaded_numel != expected_numel:
                return result

            # The full tensor is present, so quantize right away.
            self.process_weights_after_loading(layer)

            # Drop the bookkeeping and turn the regular
            # `process_weights_after_loading` call into a no-op.
            del layer._loaded_numel
            layer._already_called_process_weights_after_loading = True
            return result

        unquantized = torch.empty(
            rows,
            input_size_per_partition,
            dtype=params_dtype,
        )
        weight = ModelWeightParameter(
            data=unquantized,
            input_dim=1,
            output_dim=0,
            weight_loader=patched_weight_loader,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize ``layer.weight`` to fp8 and install the results.

        Returns immediately when the layer is flagged with
        ``_already_called_process_weights_after_loading``. Otherwise it
        replaces ``weight`` with the transposed quantized tensor, sets
        ``weight_scale``, and runs the marlin preparation step when
        ``self.use_marlin`` is set.
        """
        already_done = getattr(
            layer, "_already_called_process_weights_after_loading", False
        )
        if already_done:
            return

        # TODO(future): support block_quant in online quant path
        assert not self.block_quant

        layer.input_scale = None
        # No scale is supplied here; presumably the op derives one from the
        # weight itself -- confirm against ops.scaled_fp8_quant.
        quantized, scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        transposed = quantized.t()

        # Update layer with new values.
        replace_parameter(layer, "weight", transposed.data)
        replace_parameter(layer, "weight_scale", scale.data)

        if not self.use_marlin:
            return

        # Activations not quantized for marlin.
        size_k_first = True
        prepare_fp8_layer_for_marlin(
            layer, size_k_first, input_dtype=self.marlin_input_dtype
        )
class Fp8MoEMethod(FusedMoEMethodBase): class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8. """MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and Supports loading FP8 checkpoints with static weight scale and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment