Unverified commit b819381f, authored by HAI, committed by GitHub

AITER backend extension and workload optimizations (#6838)


Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
parent 562f279a
In the AMD CI workflow, the large MoE accuracy test now disables AITER explicitly:

```diff
@@ -72,7 +72,7 @@ jobs:
       - name: Evaluate accuracy (TP=2)
         timeout-minutes: 30
         run: |
-          bash scripts/amd_ci_exec.sh python3 test_moe_eval_accuracy_large.py
+          bash scripts/amd_ci_exec.sh -e SGLANG_USE_AITER=0 python3 test_moe_eval_accuracy_large.py

   mla-test-1-gpu-amd:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
```
In the environment-variable documentation, `SGLANG_AITER_MOE` is renamed and broadened beyond MoE:

```diff
@@ -53,7 +53,7 @@ SGLang supports various environment variables that can be used to configure its
 | Environment Variable | Description | Default Value |
 | --- | --- | --- |
-| `SGLANG_AITER_MOE` | Use AITER MOE implementation | `false` |
+| `SGLANG_USE_AITER` | Use AITER optimized implementations | `false` |
 | `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
 | `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
 | `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
```
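`SGLANG_USE_AITER` follows the same boolean convention as the other flags above. A minimal sketch of the gating pattern used throughout this change (the `get_bool_env_var` body here is a hypothetical re-implementation for illustration; the real helper lives in `sglang.srt.utils`):

```python
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Hypothetical stand-in: treat "1"/"true" (case-insensitive) as truthy.
    return os.getenv(name, default).strip().lower() in ("1", "true")

# Computed once at import time in each touched module: AITER kernels are
# only ever enabled on ROCm (HIP) builds, regardless of the env var.
_is_hip = True  # stand-in for sglang.srt.utils.is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
```

Tying the flag to `_is_hip` at module scope means `SGLANG_USE_AITER=1` is a safe no-op on CUDA builds.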
In the RMSNorm layer, AITER's fused kernels take over on ROCm when the flag is set:

```diff
@@ -20,10 +20,11 @@
 import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import (
@@ -33,7 +34,10 @@ if _is_cuda:
         rmsnorm,
     )

-if _is_hip:
+if _use_aiter:
+    from aiter import rmsnorm2d_fwd as rms_norm
+    from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
+elif _is_hip:
     from vllm._custom_ops import fused_add_rms_norm, rms_norm

 logger = logging.getLogger(__name__)
@@ -48,6 +52,8 @@ class RMSNorm(CustomOp):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        if _use_aiter:
+            self._forward_method = self.forward_aiter

     def forward_cuda(
         self,
@@ -60,6 +66,25 @@ class RMSNorm(CustomOp):
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out

+    def forward_aiter(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            residual_out = torch.empty_like(x)
+            output = torch.empty_like(x)
+            fused_add_rms_norm(
+                output,
+                x,
+                residual,
+                residual_out,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return output, residual_out
+        return rms_norm(x, self.weight.data, self.variance_epsilon)
+
     def forward_hip(
         self,
         x: torch.Tensor,
```
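For reference, the two AITER entry points are expected to implement the standard RMSNorm contract: normalize by the root-mean-square of the (optionally residual-added) activations, and in the fused case also return the pre-normalization sum. A pure-PyTorch sketch of those semantics (not the kernel's implementation):

```python
import torch

def rms_norm_ref(x, weight, eps, residual=None):
    # Reference semantics for rms_norm / fused_add_rms_norm above.
    if residual is not None:
        x = x + residual      # the fused kernel also returns this sum
        residual_out = x
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return (out, residual_out) if residual is not None else out
```

Overriding `self._forward_method` in `__init__` routes every call through `forward_aiter` without disturbing the `CustomOp` dispatch used on other platforms.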
The Triton fused-MoE padding check picks up the renamed flag:

```diff
@@ -1332,7 +1332,7 @@ def fused_experts_impl(
     if (
         not (use_fp8_w8a8 or use_int8_w8a8)
         or block_shape is not None
-        or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE"))
+        or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
     ):
         padded_size = 0
```
The unquantized fused-MoE method gates its AITER imports and weight pre-shuffling on the combined flag:

```diff
@@ -28,8 +28,9 @@ else:
 import logging

 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

-if _is_hip:
+if _use_aiter:
     from aiter import ActivationType
     from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
@@ -104,7 +105,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+        if _use_aiter:
             layer.w13_weight = torch.nn.Parameter(
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
@@ -188,7 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             routed_scaling_factor=routed_scaling_factor,
         )

-        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+        if _use_aiter:
             assert not no_combine, "unsupported"
             if apply_router_weight_on_input:
                 assert (
```
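The pre-shuffle converts each expert's weights into the (16, 16) tile layout the AITER CK kernels consume, once at load time rather than per forward pass. A sketch of the pattern (the `w2_weight` handling is an assumption; this diff view is truncated after `w13_weight`):

```python
import torch
from aiter.ops.shuffle import shuffle_weight  # same import as in the diff

def preshuffle_moe_weights(layer: torch.nn.Module) -> None:
    # One-time layout conversion so the hot path never reshuffles.
    layer.w13_weight = torch.nn.Parameter(
        shuffle_weight(layer.w13_weight.data, (16, 16)), requires_grad=False
    )
    # Assumed symmetric treatment of the down-projection weights.
    layer.w2_weight = torch.nn.Parameter(
        shuffle_weight(layer.w2_weight.data, (16, 16)), requires_grad=False
    )
```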
In the FP8 MoE method, the module-level flags move to the private underscore convention, and the AITER gate now also requires ROCm:

```diff
@@ -77,8 +77,8 @@
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()

-use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
-use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+_use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_hip:
     from aiter import ActivationType, QuantType
@@ -487,7 +487,7 @@ class Fp8MoEMethod:
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

         if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
+            params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn
         tp_size = get_tensor_model_parallel_world_size()
         if self.block_quant:
             block_n, block_k = (
@@ -512,7 +512,7 @@ class Fp8MoEMethod:
         )

         # WEIGHTS
-        if _is_hip and use_hip_int4:
+        if _is_hip and _use_hip_int4:
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -641,7 +641,7 @@ class Fp8MoEMethod:
             layer.register_parameter("w13_weight_scale", w13_weight_scale)
             layer.register_parameter("w2_weight_scale", w2_weight_scale)

-            if _is_hip:  # and use_aiter_moe: TODO: add check back after triton kernel
+            if _is_hip:  # _use_aiter: TODO: add check back after triton kernel
                 # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
                 w13_weight_scale1 = torch.nn.Parameter(
                     torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -668,7 +668,7 @@ class Fp8MoEMethod:
             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
             set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-            if _is_hip and use_hip_int4:
+            if _is_hip and _use_hip_int4:
                 extra_weight_attrs.update(
                     {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
                 )
@@ -700,7 +700,7 @@ class Fp8MoEMethod:
         layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and use_hip_int4:
+        if _is_hip and _use_hip_int4:
             self.process_weights_hip_int4(layer)
             return
@@ -731,7 +731,7 @@ class Fp8MoEMethod:
            )
            layer.w2_input_scale = None

-            if _is_hip and use_aiter_moe:
+            if _use_aiter:
                # Pre-shuffle weights
                layer.w13_weight.data = shuffle_weight(
                    layer.w13_weight.contiguous(), (16, 16)
@@ -853,7 +853,7 @@ class Fp8MoEMethod:
             return

     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and use_aiter_moe: add after triton kernel added
+        # TODO: _use_aiter: add back after the triton kernel is added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
@@ -900,7 +900,7 @@ class Fp8MoEMethod:
                 padding_size,  # Avoid circular import
             )

-            if use_aiter_moe:
+            if _use_aiter:
                 layer.w13_weight = torch.nn.Parameter(
                     shuffle_weight(layer.w13_weight.data, (16, 16)),
                     requires_grad=False,
@@ -911,7 +911,7 @@ class Fp8MoEMethod:
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
-                # ROCm (use_aiter_moe): using column-wise scaling
+                # ROCm (_use_aiter): using column-wise scaling
                 layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
                 layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
             elif get_bool_env_var("SGLANG_MOE_PADDING"):
@@ -1041,8 +1041,8 @@ class Fp8MoEMethod:
         activation: str = "silu",
         no_combine: bool = False,
     ) -> Optional[torch.Tensor]:
-        if use_hip_int4:
-            # TODO: add triton kernel and add check use_aiter_moe
+        if _use_hip_int4:
+            # TODO: add a triton kernel and re-add the _use_aiter check
             assert not no_combine, f"{no_combine=} is not supported."
             return ck_moe_2stages(
                 x,
@@ -1058,13 +1058,13 @@ class Fp8MoEMethod:
             ),
         )

-        if use_aiter_moe:
+        if _use_aiter:
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
-                # TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
+                # TODO(_use_aiter): FP8 block_quant only supports 'silu' for the time being.
                 assert (
                     activation == "silu"
-                ), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
+                ), f"_use_aiter: FP8 block_quant {activation=} will be supported later; unset SGLANG_USE_AITER"
             return asm_moe(
                 x,
                 layer.w13_weight,
```
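Taken together, `Fp8MoEMethod.apply` on ROCm reduces to a short priority ladder. A simplified sketch (the function and its string labels are illustrative, not actual return values):

```python
def select_fp8_moe_path(use_hip_int4, use_aiter, block_quant, activation):
    # Priority ladder mirrored from the hunks above.
    if use_hip_int4:
        return "ck_moe_2stages"       # INT4 weights, FP8 compute (AITER CK)
    if use_aiter:
        # Block-quantized FP8 via AITER currently requires SiLU.
        assert not block_quant or activation == "silu"
        return "asm_moe"              # AITER assembly MoE
    return "fused_experts_triton"     # default Triton path
```

Note the semantic tightening that comes with the rename: `_use_aiter` is false on non-HIP builds even when the env var is set, whereas the old `use_aiter_moe` read the env var unconditionally.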
In the block-FP8 linear utilities, the AITER GEMM moves to the CK variant:

```diff
@@ -38,11 +38,10 @@
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

-use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
-
-if _is_hip and use_aiter_moe:
-    from aiter import gemm_a8w8_blockscale
+if _use_aiter:
+    from aiter import gemm_a8w8_blockscale_CK

 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
@@ -141,7 +140,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
         return flashinfer_gemm_w8a8_block_fp8_linear
     elif CUTLASS_BLOCK_FP8_SUPPORTED:
         return cutlass_w8a8_block_fp8_linear_with_fallback
-    elif _is_hip and use_aiter_moe:
+    elif _use_aiter:
         return aiter_w8a8_block_fp8_linear
     elif _ENABLE_JIT_DEEPGEMM:
         return deepgemm_w8a8_block_fp8_linear_with_fallback
@@ -268,12 +267,9 @@ def aiter_w8a8_block_fp8_linear(
     q_input, x_scale = per_token_group_quant_fp8(
         input_2d, block_size[1], column_major_scales=False
     )
-    output = torch.zeros(
-        [q_input.shape[0], weight.shape[0]],
-        dtype=input_2d.dtype,
-        device=q_input.device,
-    )
-    gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
+    output = gemm_a8w8_blockscale_CK(
+        q_input, weight, x_scale, weight_scale, dtype=input.dtype
+    )
     if bias is not None:
         output += bias
```
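The calling convention changes along with the kernel: the old `gemm_a8w8_blockscale` wrote into a caller-allocated, zero-initialized buffer, while `gemm_a8w8_blockscale_CK`, as used above, allocates and returns its own output with an explicit dtype. Side by side (variables as in `aiter_w8a8_block_fp8_linear`; the CK contract is inferred from the call site):

```python
# Before: caller allocates, kernel fills in place.
output = torch.zeros(
    [q_input.shape[0], weight.shape[0]],
    dtype=input_2d.dtype,
    device=q_input.device,
)
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)

# After: kernel allocates and returns, dtype passed explicitly.
output = gemm_a8w8_blockscale_CK(
    q_input, weight, x_scale, weight_scale, dtype=input.dtype
)
```

This also saves zero-filling a buffer the GEMM presumably overwrites in full.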
The model runner's automatic attention-backend selection learns about AITER: on ROCm it picks the new backend when the KV-head count is supported, and `"aiter"` joins the list of MLA-capable backends:

```diff
@@ -355,6 +355,15 @@ class ModelRunner:
                 # MLA architecture
                 if is_hopper_with_cuda_12_3():
                     server_args.attention_backend = "fa3"
+                elif _is_hip:
+                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
+                    # TODO: the current AITER backend only supports 16 or 128 KV heads
+                    if (
+                        head_num == 128 or head_num == 16
+                    ) and self.spec_algorithm.is_none():
+                        server_args.attention_backend = "aiter"
+                    else:
+                        server_args.attention_backend = "triton"
                 else:
                     server_args.attention_backend = "triton"
                 logger.info(
@@ -363,6 +372,7 @@
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
+                    "aiter",
                     "flashinfer",
                     "fa3",
                     "triton",
```
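Flattened out, the new ROCm branch of the auto-selection is a small decision table. A standalone sketch (argument names are illustrative):

```python
def pick_mla_attention_backend(
    is_hopper_cuda_12_3: bool,
    is_hip: bool,
    num_kv_heads: int,
    no_spec_decoding: bool,
) -> str:
    # Mirrors the selection logic in the hunk above.
    if is_hopper_cuda_12_3:
        return "fa3"
    if is_hip and num_kv_heads in (16, 128) and no_spec_decoding:
        return "aiter"  # AITER MLA kernels support only 16 or 128 KV heads
    return "triton"     # fallback everywhere else
```

Note that the head count comes from `get_num_kv_heads(self.tp_size)`, i.e. the per-GPU KV-head count after tensor-parallel sharding, which is why the check depends on the TP size.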
DeepseekV2's MLA attention wires in the AITER backend: it imports AITER's rotary embedding, dispatches plain MHA for ordinary extend batches, and lazily binds `kv_b_proj` onto the MHA module on first use:

```diff
@@ -105,6 +105,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
@@ -120,6 +121,9 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )

+if _use_aiter:
+    from aiter.rotary_embedding import get_rope
+
 logger = logging.getLogger(__name__)
@@ -697,6 +701,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         self.alt_stream = alt_stream
+        self.attn_mha.kv_b_proj = None

         self.w_kc = None
         self.w_vc = None
@@ -766,6 +771,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif self.attention_backend == "aiter":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
@@ -813,6 +827,9 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ):
+        if self.attn_mha.kv_b_proj is None:
+            self.attn_mha.kv_b_proj = self.kv_b_proj
+
         if hidden_states.shape[0] == 0:
             assert (
                 not self.o_proj.reduce_results
```
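The AITER dispatch mirrors the Triton policy in intent: plain MHA where full attention over fresh tokens is cheap (ordinary prefill/extend), absorbed MLA where the compressed-KV path wins (decode, target-verify, draft-extend). As a standalone predicate (the enum is a minimal stand-in for the module's `AttnForwardMethod`):

```python
from enum import Enum, auto

class AttnForwardMethod(Enum):  # stand-in for the real enum
    MHA = auto()
    MLA = auto()

def aiter_attn_forward_method(forward_mode) -> AttnForwardMethod:
    plain_extend = (
        forward_mode.is_extend()
        and not forward_mode.is_target_verify()
        and not forward_mode.is_draft_extend()
    )
    return AttnForwardMethod.MHA if plain_extend else AttnForwardMethod.MLA
```

As for the lazy binding, `attn_mha.kv_b_proj` starts as `None` and is attached on first use; a plausible motivation is keeping the projection out of `attn_mha`'s registered submodules during construction and weight loading, though the diff itself does not state the reason.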
The CI exec script switches from a flat argument array to an associative map, so a user-supplied `-e KEY=VAL` overrides the defaults instead of appending a duplicate flag:

```diff
 #!/bin/bash

 set -euo pipefail

-# Default working directory
 WORKDIR="/sglang-checkout/test/srt"

-ENV_ARGS=(
-  -e SGLANG_AMD_CI=1
-  -e SGLANG_IS_IN_CI=1
-  -e SGLANG_AITER_MOE=1
+declare -A ENV_MAP=(
+  [SGLANG_AMD_CI]=1
+  [SGLANG_IS_IN_CI]=1
+  [SGLANG_USE_AITER]=1
 )

-# Parse optional -w/--workdir and -e ENV=VAL flags
+# Parse -w/--workdir and -e ENV=VAL
 while [[ $# -gt 0 ]]; do
   case "$1" in
     -w|--workdir)
@@ -17,7 +16,8 @@ while [[ $# -gt 0 ]]; do
       shift 2
       ;;
     -e)
-      ENV_ARGS+=("-e" "$2")
+      IFS="=" read -r key val <<< "$2"
+      ENV_MAP["$key"]="$val"
       shift 2
       ;;
     --)
@@ -30,6 +30,12 @@ while [[ $# -gt 0 ]]; do
   esac
 done

+# Build final ENV_ARGS
+ENV_ARGS=()
+for key in "${!ENV_MAP[@]}"; do
+  ENV_ARGS+=("-e" "$key=${ENV_MAP[$key]}")
+done
+
 # Run docker exec
 docker exec \
   -w "$WORKDIR" \
```
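This last-write-wins behavior is exactly what the workflow change at the top relies on: `bash scripts/amd_ci_exec.sh -e SGLANG_USE_AITER=0 python3 test_moe_eval_accuracy_large.py` now replaces the script's default `SGLANG_USE_AITER=1` in `ENV_MAP`, instead of passing two conflicting `-e` flags to `docker exec` and depending on whichever one the Docker CLI happens to honor.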
Finally, the nightly GSM8K test sets the renamed variable:

```diff
@@ -171,7 +171,7 @@ class TestNightlyGsm8KEval(unittest.TestCase):
             os.environ["HF_HUB_DISABLE_XET"] = (
                 "1" if model in DISABLE_HF_XET_MODELS else "0"
             )
-            os.environ["SGLANG_AITER_MOE"] = (
+            os.environ["SGLANG_USE_AITER"] = (
                 "0" if model in TRITON_MOE_MODELS else "1"
             )
```