Commit f9795c8c authored by zhuwenwen
Browse files

增加dpsk awq mtp推理的支持

parent 058b32ae
......@@ -27,6 +27,7 @@ from .deepseek_v2 import (DeepseekV2DecoderLayer,
from .interfaces import SupportsPP
from .utils import maybe_prefix
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
class SharedHead(nn.Module):
......@@ -164,6 +165,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
self.quant_method = quant_config.get_name()
os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0'
# The AWQ layer of MTP uses BlockInt8W8A8.
if self.quant_method == "moe_wna16":
vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment