Unverified Commit 6ee081d1 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Add new tp plan styles to the Transformers modelling backend (#40467)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 66cc3fa5
...@@ -94,7 +94,15 @@ def init_on_device_without_buffers(device: torch.device): ...@@ -94,7 +94,15 @@ def init_on_device_without_buffers(device: torch.device):
setattr(torch, torch_function_name, old_torch_function) setattr(torch, torch_function_name, old_torch_function)
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"] Style = Literal[
"colwise",
"rowwise",
"replicate",
"colwise_gather_output",
"rowwise_split_input",
"colwise_rep",
"rowwise_rep",
]
def replace_linear_class( def replace_linear_class(
...@@ -120,10 +128,14 @@ def replace_linear_class( ...@@ -120,10 +128,14 @@ def replace_linear_class(
vllm_linear_cls, vllm_linear_kwargs = { vllm_linear_cls, vllm_linear_kwargs = {
"colwise": (ColumnParallelLinear, {}), "colwise": (ColumnParallelLinear, {}),
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
"rowwise": (RowParallelLinear, {}), "rowwise": (RowParallelLinear, {}),
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
"replicate": (ReplicatedLinear, {}), "replicate": (ReplicatedLinear, {}),
# Transformers v5
"colwise_gather_output": (ColumnParallelLinear, {"gather_output": True}),
"rowwise_split_input": (RowParallelLinear, {"input_is_parallel": False}),
# Transformers v4
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
}.get(style, (ReplicatedLinear, {})) }.get(style, (ReplicatedLinear, {}))
return vllm_linear_cls( return vllm_linear_cls(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment