Unverified commit 29e5d102, authored by Yan Ma, committed by GitHub
Browse files

fix online fp8 for MiniCPM models (#39862)


Signed-off-by: Yan Ma <yan.ma@intel.com>
parent 235e1f93
...@@ -1050,9 +1050,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1050,9 +1050,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "resampler"), prefix=maybe_prefix(prefix, "resampler"),
) )
self._resampler_moved = False
self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors
def _ensure_resampler_device(self) -> None:
    """Move the resampler to the current platform's device, at most once.

    ``self._resampler_moved`` records whether the move has already been
    performed, so any later call is a no-op.
    """
    if not self._resampler_moved:
        # Pass only the device to ``.to()``; the dtype is deliberately left
        # alone because fp8 quantization needs its own dtype.
        self.resampler.to(current_platform.device_type)
        self._resampler_moved = True
def _parse_and_validate_vision_input( def _parse_and_validate_vision_input(
self, self,
modality: str, modality: str,
...@@ -1171,7 +1179,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1171,7 +1179,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) loaded = loader.load_weights(weights)
self._ensure_resampler_device()
return loaded
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
""" """
...@@ -1276,9 +1286,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel): ...@@ -1276,9 +1286,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
prefix=prefix, prefix=prefix,
) )
return resampler.to( return resampler.to(dtype=torch.get_default_dtype())
device=current_platform.device_type, dtype=torch.get_default_dtype()
)
def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"] pixel_values = data["pixel_values"]
...@@ -1359,9 +1367,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1359,9 +1367,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
prefix=prefix, prefix=prefix,
) )
return resampler.to( return resampler.to(dtype=torch.get_default_dtype())
device=current_platform.device_type, dtype=torch.get_default_dtype()
)
def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"] pixel_values = data["pixel_values"]
...@@ -1452,11 +1458,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1452,11 +1458,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
) )
target_device = current_platform.device_type
target_dtype = torch.get_default_dtype() return resampler.to(dtype=torch.get_default_dtype())
if any(p.is_meta for p in resampler.parameters()):
return resampler.to_empty(device=target_device).to(dtype=target_dtype)
return resampler.to(device=target_device, dtype=target_dtype)
def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"] pixel_values = data["pixel_values"]
...@@ -1491,7 +1494,9 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1491,7 +1494,9 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"]) loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"])
return loader.load_weights(weights) loaded = loader.load_weights(weights)
self._ensure_resampler_device()
return loaded
class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
...@@ -1551,10 +1556,7 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1551,10 +1556,7 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
) )
return resampler.to(dtype=torch.get_default_dtype())
return resampler.to(
device=current_platform.device_type, dtype=torch.get_default_dtype()
)
def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"] pixel_values = data["pixel_values"]
...@@ -1589,7 +1591,9 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1589,7 +1591,9 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"]) loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"])
return loader.load_weights(weights) loaded = loader.load_weights(weights)
self._ensure_resampler_device()
return loaded
class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA): class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
...@@ -1649,11 +1653,8 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1649,11 +1653,8 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
) )
target_device = current_platform.device_type
target_dtype = torch.get_default_dtype() return resampler.to(dtype=torch.get_default_dtype())
if any(p.is_meta for p in resampler.parameters()):
return resampler.to_empty(device=target_device).to(dtype=target_dtype)
return resampler.to(device=target_device, dtype=target_dtype)
def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"] pixel_values = data["pixel_values"]
...@@ -1692,7 +1693,9 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1692,7 +1693,9 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"]) loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"])
return loader.load_weights(weights) loaded = loader.load_weights(weights)
self._ensure_resampler_device()
return loaded
_SUPPORT_VERSION = { _SUPPORT_VERSION = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment