Unverified Commit 955b5191 authored by Dipika Sikka's avatar Dipika Sikka Committed by GitHub
Browse files

[Misc] update fp8 to use `vLLMParameter` (#7437)

parent 55d63b12
...@@ -15,3 +15,4 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main ...@@ -15,3 +15,4 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
awq, casperhansen/mixtral-instruct-awq, main awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
\ No newline at end of file
...@@ -22,7 +22,7 @@ logger = init_logger(__name__) ...@@ -22,7 +22,7 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [ WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod" "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod"
] ]
...@@ -349,6 +349,11 @@ class ColumnParallelLinear(LinearBase): ...@@ -349,6 +349,11 @@ class ColumnParallelLinear(LinearBase):
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight=loaded_weight) param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_): def forward(self, input_):
...@@ -1021,6 +1026,13 @@ class RowParallelLinear(LinearBase): ...@@ -1021,6 +1026,13 @@ class RowParallelLinear(LinearBase):
def weight_loader_v2(self, param: BasevLLMParameter, def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor): loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_row_parallel_weight(loaded_weight=loaded_weight) param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_): def forward(self, input_):
......
...@@ -19,9 +19,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -19,9 +19,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped) is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, convert_to_channelwise, all_close_1d, apply_fp8_linear, convert_to_channelwise,
create_per_tensor_scale_param, cutlass_fp8_supported, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
requantize_with_max_scale) requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once from vllm.utils import is_hip, print_warning_once
...@@ -137,6 +138,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -137,6 +138,7 @@ class Fp8LinearMethod(LinearMethodBase):
): ):
del input_size, output_size del input_size, output_size
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
...@@ -148,34 +150,41 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -148,34 +150,41 @@ class Fp8LinearMethod(LinearMethodBase):
weight_dtype = (torch.float8_e4m3fn weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype) params_dtype)
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition, weight = ModelWeightParameter(data=torch.empty(
dtype=weight_dtype), output_size_per_partition,
requires_grad=False) input_size_per_partition,
dtype=weight_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
**extra_weight_attrs,
"input_dim": 1,
"output_dim": 0,
})
# If checkpoint is serialized fp8, load them. # If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading. # Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized: if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE # WEIGHT SCALE
scale = create_per_tensor_scale_param(output_partition_sizes, scale = PerTensorScaleParameter(data=torch.empty(
**extra_weight_attrs) len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale) layer.register_parameter("weight_scale", scale)
# INPUT ACTIVATION SCALE # INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static": if self.quant_config.activation_scheme == "static":
scale = create_per_tensor_scale_param(output_partition_sizes, scale = PerTensorScaleParameter(data=torch.empty(
**extra_weight_attrs) len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale) layer.register_parameter("input_scale", scale)
else: else:
layer.register_parameter("input_scale", None) layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights. # If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized: if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
...@@ -197,6 +206,11 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -197,6 +206,11 @@ class Fp8LinearMethod(LinearMethodBase):
# If checkpoint is fp8, handle that there are N scales for N # If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module # shards in a fused module
else: else:
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
if self.quant_config.activation_scheme == "static":
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
# If using marlin (w8a16), kernel uses channelwise weights, # If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise. # so extend the weight scales to be channelwise.
if self.use_marlin: if self.use_marlin:
......
...@@ -208,10 +208,17 @@ class PerTensorScaleParameter(BasevLLMParameter): ...@@ -208,10 +208,17 @@ class PerTensorScaleParameter(BasevLLMParameter):
if isinstance(shard_id, int): if isinstance(shard_id, int):
return shard_id return shard_id
# if not int, assume shard_id for qkv
# map to int and return
assert isinstance(shard_id, str) assert isinstance(shard_id, str)
assert shard_id in self.qkv_idxs assert shard_id in self.qkv_idxs
return self.qkv_idxs[shard_id] return self.qkv_idxs[shard_id]
# For row parallel layers, no sharding needed
# load weight into parameter as is
def load_row_parallel_weight(self, *args, **kwargs):
super().load_row_parallel_weight(*args, **kwargs)
def load_merged_column_weight(self, *args, **kwargs): def load_merged_column_weight(self, *args, **kwargs):
self._load_into_shard_id(*args, **kwargs) self._load_into_shard_id(*args, **kwargs)
...@@ -219,7 +226,7 @@ class PerTensorScaleParameter(BasevLLMParameter): ...@@ -219,7 +226,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
self._load_into_shard_id(*args, **kwargs) self._load_into_shard_id(*args, **kwargs)
def load_column_parallel_weight(self, *args, **kwargs): def load_column_parallel_weight(self, *args, **kwargs):
self._load_into_shard_id(*args, **kwargs) super().load_row_parallel_weight(*args, **kwargs)
def _load_into_shard_id(self, loaded_weight: torch.Tensor, def _load_into_shard_id(self, loaded_weight: torch.Tensor,
shard_id: Union[str, int], **kwargs): shard_id: Union[str, int], **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment