Unverified Commit 78648758 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix qwen2.5-vl overflow issue (#13968)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
parent 1dd422b6
...@@ -47,7 +47,7 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, ...@@ -47,7 +47,7 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
MiniCPMVMultiModalDataParser, MiniCPMVMultiModalDataParser,
MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo, MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
_minicpmv_field_config) _minicpmv_field_config)
from .utils import AutoWeightsLoader, maybe_prefix from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
...@@ -469,13 +469,8 @@ class MiniCPMWhisperEncoderLayer(nn.Module): ...@@ -469,13 +469,8 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
training=self.training) training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and ( if hidden_states.dtype == torch.float16:
torch.isinf(hidden_states).any() hidden_states = cast_overflow_tensors(hidden_states)
or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states,
min=-clamp_value,
max=clamp_value)
outputs = (hidden_states, ) outputs = (hidden_states, )
......
...@@ -63,7 +63,7 @@ from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP ...@@ -63,7 +63,7 @@ from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
apply_rotary_pos_emb_vision) apply_rotary_pos_emb_vision)
from .utils import (AutoWeightsLoader, WeightsMapper, from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import get_vit_attn_backend from .vision import get_vit_attn_backend
...@@ -641,6 +641,11 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -641,6 +641,11 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_seqlens=cu_seqlens_now, cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb) rotary_pos_emb=rotary_pos_emb)
# For Qwen2.5-VL-3B, float16 will overflow at last block
# for long visual tokens sequences.
if hidden_states.dtype == torch.float16:
hidden_states = cast_overflow_tensors(hidden_states)
# adapter # adapter
hidden_states = self.merger(hidden_states) hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index) reverse_indices = torch.argsort(window_index)
......
...@@ -641,3 +641,13 @@ def extract_layer_index(layer_name: str) -> int: ...@@ -641,3 +641,13 @@ def extract_layer_index(layer_name: str) -> int:
assert len(int_vals) == 1, (f"layer name {layer_name} should" assert len(int_vals) == 1, (f"layer name {layer_name} should"
" only contain one integer") " only contain one integer")
return int_vals[0] return int_vals[0]
def cast_overflow_tensors(
tensors: torch.Tensor,
offset: float = 1000,
) -> torch.Tensor:
if tensors.isinf().any() or tensors.isnan().any():
clamp_value = torch.finfo(tensors.dtype).max - offset
tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
return tensors
...@@ -35,7 +35,8 @@ from vllm.multimodal.processing import (BaseProcessingInfo, ...@@ -35,7 +35,8 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .interfaces import SupportsMultiModal, SupportsTranscription from .interfaces import SupportsMultiModal, SupportsTranscription
from .utils import AutoWeightsLoader, WeightsMapper, make_layers from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
make_layers)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -285,11 +286,7 @@ class WhisperEncoderLayer(nn.Module): ...@@ -285,11 +286,7 @@ class WhisperEncoderLayer(nn.Module):
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
if hidden_states.isinf().any() or hidden_states.isnan().any(): hidden_states = cast_overflow_tensors(hidden_states)
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states,
min=-clamp_value,
max=clamp_value)
return hidden_states return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment