"tests/vscode:/vscode.git/clone" did not exist on "d4f0d2bcc49b17eb04423bc03e601e5153442fed"
Unverified commit 6ee081d1, authored by Harry Mellor and committed by GitHub
Browse files

Add new tp plan styles to the Transformers modelling backend (#40467)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 66cc3fa5
......@@ -94,7 +94,15 @@ def init_on_device_without_buffers(device: torch.device):
setattr(torch, torch_function_name, old_torch_function)
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
# Names of the tensor-parallel plan styles that the class/kwargs mapping in
# replace_linear_class has an explicit entry for.
Style = Literal[
"colwise",
"rowwise",
"replicate",
# Names labelled "Transformers v5" in the replace_linear_class mapping.
"colwise_gather_output",
"rowwise_split_input",
# Names labelled "Transformers v4" in that mapping; they map to the same
# linear class and kwargs as the two v5 names above.
"colwise_rep",
"rowwise_rep",
]
def replace_linear_class(
......@@ -120,10 +128,14 @@ def replace_linear_class(
vllm_linear_cls, vllm_linear_kwargs = {
"colwise": (ColumnParallelLinear, {}),
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
"rowwise": (RowParallelLinear, {}),
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
"replicate": (ReplicatedLinear, {}),
# Transformers v5
"colwise_gather_output": (ColumnParallelLinear, {"gather_output": True}),
"rowwise_split_input": (RowParallelLinear, {"input_is_parallel": False}),
# Transformers v4
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
}.get(style, (ReplicatedLinear, {}))
return vllm_linear_cls(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment