Unverified Commit 4286cc5e authored by tc-mb's avatar tc-mb Committed by GitHub
Browse files

fix(minicpmv): fix audio inference by handling meta device in init_re… (#36751)


Signed-off-by: default avatarcaitianchi <caitianchi@modelbest.cn>
parent 545d18d8
......@@ -1453,10 +1453,11 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config,
prefix=prefix,
)
return resampler.to(
device=current_platform.device_type, dtype=torch.get_default_dtype()
)
target_device = current_platform.device_type
target_dtype = torch.get_default_dtype()
if any(p.is_meta for p in resampler.parameters()):
return resampler.to_empty(device=target_device).to(dtype=target_dtype)
return resampler.to(device=target_device, dtype=target_dtype)
def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
......@@ -1649,10 +1650,11 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config,
prefix=prefix,
)
return resampler.to(
device=current_platform.device_type, dtype=torch.get_default_dtype()
)
target_device = current_platform.device_type
target_dtype = torch.get_default_dtype()
if any(p.is_meta for p in resampler.parameters()):
return resampler.to_empty(device=target_device).to(dtype=target_dtype)
return resampler.to(device=target_device, dtype=target_dtype)
def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment