Unverified Commit c091c0a5 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Improve validation of TP in Transformers backend (#15540)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 1aa162e0
...@@ -229,7 +229,10 @@ class TransformersModel(nn.Module): ...@@ -229,7 +229,10 @@ class TransformersModel(nn.Module):
Apply the model's tensor parallelization plan. Apply the model's tensor parallelization plan.
Currently only supports linear layers. Currently only supports linear layers.
""" """
if self.tp_size > 1 and self.config.base_model_tp_plan is None: if not self.model.supports_tp_plan:
if self.tp_size <= 1:
return
raise ValueError( raise ValueError(
f"{type(self.model)} does not support tensor parallel yet!") f"{type(self.model)} does not support tensor parallel yet!")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment