"tests/vscode:/vscode.git/clone" did not exist on "cd68e6e55a2155ffaea681b18e012aafd686dce4"
Unverified Commit 27985c27 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

feat: update model config (#9202)

parent ac474869
@@ -24,7 +24,7 @@ import tempfile
 from typing import List, Literal, Optional, Union
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
@@ -2117,11 +2117,25 @@ class ServerArgs:
         model_arch = hf_config.architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
             if self.attention_backend is None:
-                self.attention_backend = "triton"
+                if is_sm100_supported():
+                    self.attention_backend = "trtllm_mha"
+                elif is_sm90_supported():
+                    self.attention_backend = "fa3"
+                else:
+                    self.attention_backend = "triton"
             supported_backends = ["triton", "trtllm_mha", "fa3"]
+            logger.info(
+                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
+            )
             assert (
                 self.attention_backend in supported_backends
             ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+            if is_sm100_supported():
+                self.enable_flashinfer_allreduce_fusion = True
+                logger.info(
+                    "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
+                )
             quantization_config = getattr(hf_config, "quantization_config", None)
             is_mxfp4_quant_format = (
                 quantization_config is not None
......
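Note: the second hunk selects the attention backend for GptOssForCausalLM from the GPU generation when the user has not set one: trtllm_mha on SM100 (Blackwell), fa3 on SM90 (Hopper), triton otherwise, and additionally enables FlashInfer allreduce fusion on SM100. The sketch below is a minimal standalone illustration of that dispatch, not the repository's code: the function name pick_gpt_oss_attention_backend and the direct torch.cuda capability probe are assumptions; the real checks are is_sm90_supported / is_sm100_supported in sglang.srt.layers.utils and may differ in detail (for example exact-match versus >= comparisons).

# Standalone sketch of the SM-architecture dispatch added in this commit.
# Assumption: compute capability (10, x) ~ SM100/Blackwell, (9, x) ~ SM90/Hopper.
import torch


def pick_gpt_oss_attention_backend() -> str:
    """Mirror the hunk above: trtllm_mha on SM100, fa3 on SM90, else triton."""
    if not torch.cuda.is_available():
        return "triton"
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) >= (10, 0):  # Blackwell-class GPU (SM100)
        return "trtllm_mha"
    if (major, minor) >= (9, 0):   # Hopper-class GPU (SM90)
        return "fa3"
    return "triton"                # portable fallback


if __name__ == "__main__":
    backend = pick_gpt_oss_attention_backend()
    assert backend in ["triton", "trtllm_mha", "fa3"]
    print(f"Use {backend} as attention backend for GptOssForCausalLM")

A backend passed explicitly via --attention-backend is left untouched by this logic; the assert in the hunk still rejects any value outside the three supported backends.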