Unverified Commit fdc5df6f authored by Zhe Zhang's avatar Zhe Zhang Committed by GitHub
Browse files

use device param in load_model method (#13037)

parent 3b05cd45
...@@ -1107,7 +1107,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1107,7 +1107,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def load_model(self) -> None: def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler(self.device) as m:
self.model = get_model(vllm_config=self.vllm_config) self.model = get_model(vllm_config=self.vllm_config)
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment