Unverified commit fdfd409f, authored by Chenyaaang, committed by GitHub
Browse files

[TPU][Core] Make load weight exceed HBM error more instructive for customers (#20644)


Signed-off-by: Chenyaaang <chenyangli@google.com>
parent ffbcc9e7
...@@ -1128,6 +1128,7 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -1128,6 +1128,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"vllm.model_executor.layers.vocab_parallel_embedding." "vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank", "get_tensor_model_parallel_rank",
return_value=xm_tp_rank): return_value=xm_tp_rank):
try:
if self.use_spmd: if self.use_spmd:
tpu_loader = TPUModelLoader( tpu_loader = TPUModelLoader(
load_config=self.vllm_config.load_config) load_config=self.vllm_config.load_config)
...@@ -1136,7 +1137,6 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -1136,7 +1137,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
model_config=self.vllm_config.model_config, model_config=self.vllm_config.model_config,
mesh=self.mesh) mesh=self.mesh)
else: else:
# model = get_model(vllm_config=self.vllm_config)
model_loader = get_model_loader(self.load_config) model_loader = get_model_loader(self.load_config)
if not hasattr(self, "model"): if not hasattr(self, "model"):
logger.info("Loading model from scratch...") logger.info("Loading model from scratch...")
...@@ -1146,8 +1146,15 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -1146,8 +1146,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
else: else:
logger.info("Model was already initialized. \ logger.info("Model was already initialized. \
Loading weights inplace...") Loading weights inplace...")
model_loader.load_weights(self.model, model_loader.load_weights(
model_config=self.model_config) self.model, model_config=self.model_config)
except RuntimeError as e:
raise RuntimeError(
f"Unable to load model, a likely reason is the model is "
"too large for the current device's HBM memory. "
"Consider switching to a smaller model "
"or sharding the weights on more chips. "
f"See the detailed error: {e}") from e
if self.lora_config is not None: if self.lora_config is not None:
model = self.load_lora_model(model, self.model_config, model = self.load_lora_model(model, self.model_config,
self.scheduler_config, self.scheduler_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment