Unverified Commit 08b1195e authored by whx's avatar whx Committed by GitHub
Browse files

[PluggableLayer][2/N] Apply PluggableLayer to linear layers (#33152)


Signed-off-by: default avatarwhx-sjtu <2952154980@qq.com>
parent 3bba2edb
......@@ -17,7 +17,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
......@@ -239,7 +239,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
class LinearBase(CustomOp):
class LinearBase(PluggableLayer):
"""Base linear layer.
Args:
......@@ -294,7 +294,7 @@ class LinearBase(CustomOp):
# --8<-- [start:replicated_linear]
@CustomOp.register("replicated_linear")
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
......@@ -414,7 +414,7 @@ class ReplicatedLinear(LinearBase):
# --8<-- [start:column_parallel_linear]
@CustomOp.register("column_parallel_linear")
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
......@@ -1273,7 +1273,7 @@ class QKVParallelLinear(ColumnParallelLinear):
# --8<-- [start:row_parallel_linear]
@CustomOp.register("row_parallel_linear")
@PluggableLayer.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment