Commit 9d774fc5 authored by zhangzbb
Browse files

[Feature] add CustomOp Decorator for UnquantizedLinearMethod and...

[Feature] add CustomOp Decorator for UnquantizedLinearMethod and UnquantizedEmbeddingMethod for vllm_hcu
parent 8680bfdb
...@@ -39,6 +39,7 @@ from vllm.model_executor.parameter import ( ...@@ -39,6 +39,7 @@ from vllm.model_executor.parameter import (
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.model_executor.custom_op import CustomOp
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -178,8 +179,8 @@ class LinearMethodBase(QuantizeMethodBase): ...@@ -178,8 +179,8 @@ class LinearMethodBase(QuantizeMethodBase):
Expects create_weights to have been called before on the layer.""" Expects create_weights to have been called before on the layer."""
raise NotImplementedError raise NotImplementedError
@CustomOp.register("unquantized_linear_method")
class UnquantizedLinearMethod(LinearMethodBase): class UnquantizedLinearMethod(LinearMethodBase, CustomOp):
"""Linear method without quantization.""" """Linear method without quantization."""
def create_weights( def create_weights(
......
...@@ -24,11 +24,12 @@ from vllm.model_executor.layers.utils import dispatch_unquantized_gemm ...@@ -24,11 +24,12 @@ from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.model_executor.custom_op import CustomOp
DEFAULT_VOCAB_PADDING_SIZE = 64 DEFAULT_VOCAB_PADDING_SIZE = 64
@CustomOp.register("unquantized_embedding_method")
class UnquantizedEmbeddingMethod(QuantizeMethodBase): class UnquantizedEmbeddingMethod(QuantizeMethodBase, CustomOp):
"""Unquantized method for embeddings.""" """Unquantized method for embeddings."""
def create_weights( def create_weights(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment