Unverified commit 82605747 authored by Yueyang Pan, committed by GitHub

fix: fp8 quantization failure of qwen 2.5 VL 7B model (#10112)


Signed-off-by: PanJason <pyyjason@gmail.com>
parent 37f3325b
@@ -31,6 +31,7 @@ from sglang.srt.layers.parameter import (
     _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.utils import pad_or_narrow_weight
 from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs

 if TYPE_CHECKING:
@@ -625,9 +626,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # bitsandbytes loads the weights of the specific portion
             # no need to narrow here
             if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(
-                    output_dim, start_idx, shard_size
-                )
+                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+                end_idx = start_idx + shard_size
+                if end_idx > loaded_weight.shape[output_dim]:
+                    loaded_weight = pad_or_narrow_weight(
+                        loaded_weight, output_dim, start_idx, shard_size
+                    )
+                else:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )

         # Special case for AQLM codebooks.
         elif is_metadata:
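The new guard matters on the last tensor-parallel rank: once the model allocates an 8-aligned parameter, the shard computed from that padded size can run past the end of the checkpoint tensor, so the overrunning slice has to be zero-padded instead of narrowed. A self-contained sketch with made-up sizes (not the real Qwen2.5-VL dimensions):

import torch

# Hypothetical sizes for illustration only: a checkpoint weight whose output
# dimension (10) is not a multiple of 8, padded in the model to 16 and split
# across tp_size = 2 ranks.
checkpoint_out = 10                             # rows stored on disk
padded_out = ((checkpoint_out + 7) // 8) * 8    # 16, size of the model parameter
tp_size = 2
shard_size = padded_out // tp_size              # 8 rows per rank

loaded_weight = torch.randn(checkpoint_out, 4)

for tp_rank in range(tp_size):
    start_idx = tp_rank * shard_size
    end_idx = start_idx + shard_size
    # Rank 0: rows [0, 8) fit inside the 10 stored rows -> plain narrow().
    # Rank 1: rows [8, 16) overrun the stored rows -> needs zero padding.
    overruns = end_idx > loaded_weight.shape[0]
    print(f"rank {tp_rank}: rows [{start_idx}, {end_idx}) overruns={overruns}")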
@@ -1302,7 +1310,16 @@ class RowParallelLinear(LinearBase):
                 shard_size,
             )
         else:
-            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+            # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+            end_idx = start_idx + shard_size
+            if end_idx > loaded_weight.shape[input_dim]:
+                loaded_weight = pad_or_narrow_weight(
+                    loaded_weight, input_dim, start_idx, shard_size
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    input_dim, start_idx, shard_size
+                )

         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
......
@@ -7,6 +7,7 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter

+from sglang.srt.layers.utils import pad_or_narrow_weight
 from sglang.srt.utils import is_cpu

 __all__ = [
@@ -156,9 +157,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
             )
         else:
             if not use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(
-                    self.output_dim, tp_rank * shard_size, shard_size
-                )
+                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+                start_idx = tp_rank * shard_size
+                end_idx = start_idx + shard_size
+                if end_idx > loaded_weight.shape[self.output_dim]:
+                    loaded_weight = pad_or_narrow_weight(
+                        loaded_weight, self.output_dim, start_idx, shard_size
+                    )
+                else:
+                    loaded_weight = loaded_weight.narrow(
+                        self.output_dim, start_idx, shard_size
+                    )

         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
@@ -258,9 +267,17 @@ class RowvLLMParameter(BasevLLMParameter):
             return
         else:
-            loaded_weight = loaded_weight.narrow(
-                self.input_dim, tp_rank * shard_size, shard_size
-            )
+            # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+            start_idx = tp_rank * shard_size
+            end_idx = start_idx + shard_size
+            if end_idx > loaded_weight.shape[self.input_dim]:
+                loaded_weight = pad_or_narrow_weight(
+                    loaded_weight, self.input_dim, start_idx, shard_size
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.input_dim, start_idx, shard_size
+                )

         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
......
@@ -393,13 +393,23 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 x.dtype,
                 True,  # is_vnni
             )

         x_q, x_scale = per_token_quant_int8(x)
-
-        return int8_scaled_mm(
-            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
-        )
+        x_q_2d = x_q.view(-1, x_q.shape[-1])
+        x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
+        output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
+
+        output = int8_scaled_mm(
+            x_q_2d,
+            layer.weight,
+            x_scale_2d,
+            layer.weight_scale,
+            out_dtype=x.dtype,
+            bias=bias,
+        )
+
+        return output.view(output_shape)


 class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.
......
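The change above flattens the quantized activation to 2-D before the scaled matmul and restores the leading dimensions afterwards, which is what lets 3-D activations (for example from the vision tower) pass through a kernel that expects 2-D inputs. A rough, self-contained sketch of that pattern, using a plain float matmul as a stand-in for int8_scaled_mm (the stub's semantics are an assumption for illustration; the (K, N) weight layout follows layer.weight.shape[1] being the output width in the diff):

import torch

def scaled_mm_2d_stub(x_q_2d, weight, x_scale_2d, weight_scale, out_dtype):
    # Stand-in for a quantized GEMM that only accepts 2-D inputs.
    assert x_q_2d.dim() == 2 and weight.dim() == 2
    out = x_q_2d.to(torch.float32) @ weight.to(torch.float32)
    return (out * x_scale_2d * weight_scale).to(out_dtype)

# Hypothetical 3-D activation (batch, seq, hidden) with per-token scales.
x_q = torch.randint(-128, 127, (2, 5, 8), dtype=torch.int8)
x_scale = torch.rand(2, 5, 1)
weight = torch.randint(-128, 127, (8, 16), dtype=torch.int8)
weight_scale = torch.rand(16)

# Same pattern as the fix: flatten all leading dims, run the 2-D kernel,
# then restore the original leading dims on the output.
x_q_2d = x_q.view(-1, x_q.shape[-1])
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
output_shape = [*x_q.shape[:-1], weight.shape[1]]

out = scaled_mm_2d_stub(x_q_2d, weight, x_scale_2d, weight_scale, torch.float32)
print(out.view(output_shape).shape)   # torch.Size([2, 5, 16])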
@@ -15,6 +15,29 @@ def get_layer_id(weight_name):
     return None


+def pad_or_narrow_weight(
+    loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
+) -> torch.Tensor:
+    # Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
+    valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
+
+    if valid_size > 0:
+        loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
+        pad_shape = list(loaded_weight.shape)
+        pad_shape[input_dim] = shard_size - valid_size
+        pad = torch.zeros(
+            pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
+        )
+        return torch.cat([loaded_slice, pad], dim=input_dim)
+
+    # All padding
+    pad_shape = list(loaded_weight.shape)
+    pad_shape[input_dim] = shard_size
+    return torch.zeros(
+        pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
+    )
+
+
 class PPMissingLayer(torch.nn.Identity):
     # Adapted from
     # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
......
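To see both branches of the new helper, here is a short usage sketch (it assumes an sglang checkout that already contains this patch, since pad_or_narrow_weight is imported from the path added above):

import torch

from sglang.srt.layers.utils import pad_or_narrow_weight

# Toy tensor with 10 rows; pretend each shard should be 8 rows wide.
w = torch.arange(20, dtype=torch.float32).reshape(10, 2)

# Partial overlap: rows [8, 16) -> 2 real rows followed by 6 zero rows.
shard = pad_or_narrow_weight(w, input_dim=0, start_idx=8, shard_size=8)
print(shard.shape)            # torch.Size([8, 2])
print(shard[2:].abs().sum())  # tensor(0.) -- the padded tail is all zeros

# No overlap at all: start beyond the tensor -> an all-zero shard.
empty = pad_or_narrow_weight(w, input_dim=0, start_idx=16, shard_size=8)
print(empty.abs().sum())      # tensor(0.)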
@@ -265,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.fullatt_block_indexes = vision_config.fullatt_block_indexes
         self.window_size = vision_config.window_size
         self.patch_size = vision_config.patch_size
-        mlp_hidden_size: int = vision_config.intermediate_size
+        mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
......
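The model-side change rounds the vision MLP width up to the next multiple of 8 so that FP8 quantization sees an aligned shape. A quick check of the arithmetic (3420 is used here purely as an illustrative value for the vision intermediate_size, not an authoritative one):

def round_up_to_multiple_of_8(x: int) -> int:
    # Same expression as in the diff above.
    return ((x + 7) // 8) * 8

for size in (3420, 3424, 1):
    print(size, "->", round_up_to_multiple_of_8(size))
# 3420 -> 3424   (not 8-aligned, padded up by 4)
# 3424 -> 3424   (already aligned, unchanged)
# 1    -> 8

Because pad_or_narrow_weight zero-fills the rows and columns that correspond to the padded hidden units, the extra units should contribute nothing to the MLP output, so the padding changes shapes but not numerics.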