Unverified Commit 458e74eb authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Support more parallel styles in Transformers backend TP (#22651)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 65abe111
......@@ -107,10 +107,17 @@ def replace_linear_class(
raise ValueError(
f"Unsupported parallel style type {type(style)}, expected str")
vllm_linear_cls = {
"colwise": ColumnParallelLinear,
"rowwise": RowParallelLinear,
}.get(style, ReplicatedLinear)
vllm_linear_cls, vllm_linear_kwargs = {
"colwise": (ColumnParallelLinear, {}),
"colwise_rep": (ColumnParallelLinear, {
"gather_output": True
}),
"rowwise": (RowParallelLinear, {}),
"rowwise_rep": (RowParallelLinear, {
"input_is_parallel": False
}),
"replicate": (ReplicatedLinear, {}),
}.get(style, (ReplicatedLinear, {}))
return vllm_linear_cls(
input_size=linear.in_features,
......@@ -118,6 +125,7 @@ def replace_linear_class(
bias=linear.bias is not None,
quant_config=quant_config,
return_bias=False,
**vllm_linear_kwargs,
)
......@@ -506,7 +514,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
# Some weight loaders expect linear layers to inherit from vLLM's
# LinearBase class, so we set a default style which causes any
# unspecified linear layers to be replaced with ReplicatedLinear
tp_plan[".*"] = "replicated"
tp_plan[".*"] = "replicate"
def _tensor_parallel(module: nn.Module, prefix: str = ""):
for child_name, child_module in module.named_children():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment