Unverified Commit 66fb9b13 authored by Hanming Lu's avatar Hanming Lu Committed by GitHub
Browse files

[ServerArgs] allow --mamba-ssm-dtype extend (#12481)

parent 819fc591
......@@ -162,6 +162,8 @@ MOE_RUNNER_BACKEND_CHOICES = [
"cutlass",
]
MAMBA_SSM_DTYPE_CHOICES = ["float32", "bfloat16"]
# Allow external code to add more choices
def add_load_format_choices(choices):
......@@ -200,6 +202,10 @@ def add_radix_eviction_policy_choices(choices):
RADIX_EVICTION_POLICY_CHOICES.extend(choices)
def add_mamba_ssm_dtype_choices(choices):
MAMBA_SSM_DTYPE_CHOICES.extend(choices)
@dataclasses.dataclass
class ServerArgs:
"""
......@@ -2902,7 +2908,7 @@ class ServerArgs:
"--mamba-ssm-dtype",
type=str,
default=ServerArgs.mamba_ssm_dtype,
choices=["float32", "bfloat16"],
choices=MAMBA_SSM_DTYPE_CHOICES,
help="The data type of the SSM states in mamba cache.",
)
parser.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment