Unverified commit d364b9b0, authored by HAI, committed by GitHub

ROCm: update AITER (#5816)

parent 849c83a0
@@ -38,12 +38,12 @@ jobs:
          else
            DEVICE_FLAG="--device /dev/dri"
          fi
-         docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+         docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
          docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
            -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
            --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
            -w /sglang-checkout --name ci_sglang \
-           lmsysorg/sglang:v0.4.5.post3-rocm630
+           ghcr.io/saienduri/sglang-aiter-v0.1.1:428
      - name: Install dependencies
        run: |
@@ -82,12 +82,12 @@ jobs:
          else
            DEVICE_FLAG="--device /dev/dri"
          fi
-         docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+         docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
          docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
            -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
            --cap-add=SYS_PTRACE -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} --security-opt seccomp=unconfined \
            -w /sglang-checkout --name ci_sglang \
-           lmsysorg/sglang:v0.4.5.post3-rocm630
+           ghcr.io/saienduri/sglang-aiter-v0.1.1:428
      - name: Install dependencies
        run: |
@@ -120,12 +120,12 @@ jobs:
          else
            DEVICE_FLAG="--device /dev/dri"
          fi
-         docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+         docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
          docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
            -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
            --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
            -w /sglang-checkout --name ci_sglang \
-           lmsysorg/sglang:v0.4.5.post3-rocm630
+           ghcr.io/saienduri/sglang-aiter-v0.1.1:428
      - name: Install dependencies
        run: |
...
@@ -15,7 +15,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     get_config_file_name,
 )
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
 def main(model, tp_size, dtype: str, batches):
...
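For reference, the renamed flag is parsed as an integer string, so values like `true` will not work; a minimal standalone sketch of the gating line above:

```python
import os

# Pad MoE weights by 128 elements only when SGLANG_MOE_PADDING is set to a
# non-zero integer string (e.g. SGLANG_MOE_PADDING=1); the default keeps 0.
padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
print(padding_size)  # 128 if SGLANG_MOE_PADDING=1 is exported, otherwise 0
```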
@@ -18,7 +18,7 @@ ARG TRITON_COMMIT="improve_fa_decode_3.0.0"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
-ARG AITER_COMMIT="testx"
+ARG AITER_COMMIT="v0.1.1"
 RUN git clone ${SGL_REPO} \
     && cd sglang \
@@ -74,7 +74,7 @@ ENV SGLANG_SET_CPU_AFFINITY=1
 ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
 ENV NCCL_MIN_NCHANNELS=112
-ENV MOE_PADDING=1
+ENV SGLANG_MOE_PADDING=1
 ENV VLLM_FP8_PADDING=1
 ENV VLLM_FP8_ACT_PADDING=1
 ENV VLLM_FP8_WEIGHT_PADDING=1
...
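When running outside this image, the renamed variable has to be exported by hand. A minimal sketch of mirroring the Dockerfile defaults from Python before importing sglang; which of these are worth overriding depends on your deployment:

```python
import os

# Defaults taken from the Dockerfile hunk above; everything else about the
# environment (GPU count, NCCL topology, etc.) is assumed.
for name, value in {
    "SGLANG_SET_CPU_AFFINITY": "1",
    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN": "1",
    "NCCL_MIN_NCHANNELS": "112",
    "SGLANG_MOE_PADDING": "1",
    "VLLM_FP8_PADDING": "1",
}.items():
    os.environ.setdefault(name, value)
```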
@@ -45,7 +45,7 @@ if _is_cuda or _is_hip:
 logger = logging.getLogger(__name__)
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
 enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )
@@ -1327,7 +1327,7 @@ def fused_experts_impl(
     if (
         not (use_fp8_w8a8 or use_int8_w8a8)
         or block_shape is not None
-        or (_is_hip and get_bool_env_var("CK_MOE"))
+        or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE"))
    ):
        padded_size = 0
...
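The second hunk only renames the flag, but the surrounding condition is easy to misread, so here is the same gating logic as a standalone sketch; `resolve_padded_size` is a hypothetical helper used only for illustration:

```python
def resolve_padded_size(
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    block_shape,               # e.g. [128, 128] for block-quantized weights, else None
    is_hip: bool,
    aiter_moe_enabled: bool,   # result of get_bool_env_var("SGLANG_AITER_MOE")
    padding_size: int,
) -> int:
    # Padding only applies on the per-tensor FP8/INT8 Triton path; block-quantized
    # weights and the ROCm AITER MoE path skip it entirely.
    if (
        not (use_fp8_w8a8 or use_int8_w8a8)
        or block_shape is not None
        or (is_hip and aiter_moe_enabled)
    ):
        return 0
    return padding_size
```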
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -30,7 +30,9 @@ import logging
 _is_hip = is_hip()
 if _is_hip:
-    from aiter import ck_moe
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import ck_moe_2stages
+    from aiter.ops.shuffle import shuffle_weight
 logger = logging.getLogger(__name__)
@@ -102,14 +104,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _is_hip and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -182,21 +184,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             routed_scaling_factor=routed_scaling_factor,
         )
-        if _is_hip and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
             assert not no_combine, "unsupported"
-            return ck_moe(
+            return ck_moe_2stages(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
                 topk_weights,
                 topk_ids,
-                None,
-                None,
-                None,
-                None,
-                32,
-                None,
-                activation,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         else:
             return fused_experts(
@@ -527,7 +525,7 @@ class FusedMoE(torch.nn.Module):
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
-            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+            if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                 loaded_weight = loaded_weight * 2.0
             # this is needed for compressed-tensors only
@@ -569,7 +567,7 @@ class FusedMoE(torch.nn.Module):
         quant_method = getattr(param, "quant_method", None)
         if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
-            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+            if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                 loaded_weight = loaded_weight * 0.5
             self._load_per_channel_weight_scale(
@@ -592,7 +590,7 @@ class FusedMoE(torch.nn.Module):
             )
         elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
-            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+            if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                 loaded_weight = loaded_weight * 2.0
             self._load_per_tensor_weight_scale(
...
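In short, the unquantized ROCm path now pre-shuffles weights with `shuffle_weight(..., (16, 16))` instead of `permute_weight`, and replaces `ck_moe` and its long positional argument list with `ck_moe_2stages` plus an `ActivationType` enum. A condensed sketch of the resulting call shape, assuming a ROCm build with AITER v0.1.1 installed (argument order taken from the hunk above; this is not a drop-in replacement for the method itself):

```python
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import ck_moe_2stages
from aiter.ops.shuffle import shuffle_weight


def aiter_unquantized_moe(x, w13, w2, topk_weights, topk_ids, activation="silu"):
    # w13/w2 are expected to have been shuffled once at load time, e.g.
    #   w13 = shuffle_weight(w13, (16, 16)); w2 = shuffle_weight(w2, (16, 16))
    act = ActivationType.Silu if activation == "silu" else ActivationType.Gelu
    return ck_moe_2stages(x, w13, w2, topk_weights, topk_ids, activation=act)
```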
@@ -72,8 +72,8 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 if _is_hip:
-    from aiter import ActivationType
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 if not _is_cuda:
@@ -484,7 +484,7 @@ class Fp8MoEMethod:
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
                 torch.uint32
-                if get_bool_env_var("USE_INT4_WEIGHT")
+                if get_bool_env_var("SGLANG_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
         tp_size = get_tensor_model_parallel_world_size()
@@ -511,7 +511,7 @@ class Fp8MoEMethod:
         )
         # WEIGHTS
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -585,7 +585,7 @@ class Fp8MoEMethod:
         if (
             _is_hip
-        ):  # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
+        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -612,7 +612,7 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -644,7 +644,7 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None
     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             self.process_weights_hip_int4(layer)
             return
@@ -675,7 +675,7 @@ class Fp8MoEMethod:
             )
             layer.w2_input_scale = None
-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 # Pre-shuffle weights
                 layer.w13_weight.data = shuffle_weight(
                     layer.w13_weight.contiguous(), (16, 16)
@@ -798,17 +798,15 @@ class Fp8MoEMethod:
         return
     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            # permute_weight(layer.w13_weight.data),
             shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            # permute_weight(layer.w2_weight.data),
             shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )
@@ -847,23 +845,21 @@ class Fp8MoEMethod:
                 padding_size,  # Avoid circular import
             )
-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 layer.w13_weight = torch.nn.Parameter(
-                    # permute_weight(layer.w13_weight.data),
                     shuffle_weight(layer.w13_weight.data, (16, 16)),
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
                 layer.w2_weight = torch.nn.Parameter(
-                    # permute_weight(layer.w2_weight.data),
                     shuffle_weight(layer.w2_weight.data, (16, 16)),
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
-                # ROCm (CK_MOE): using column-wise scaling
+                # ROCm (SGLANG_AITER_MOE): using column-wise scaling
                 layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
                 layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-            elif get_bool_env_var("MOE_PADDING"):
+            elif get_bool_env_var("SGLANG_MOE_PADDING"):
                 # If ROCm, apply weight padding (min. Mem channel contention) only if set
                 layer.w13_weight = torch.nn.Parameter(
                     F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -912,15 +908,16 @@ class Fp8MoEMethod:
         )
         if _is_hip:
-            if get_bool_env_var("USE_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
                 assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages_win4(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
+                    QuantType.per_Token,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
                     activation=(
@@ -930,13 +927,13 @@ class Fp8MoEMethod:
                    ),
                )
-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 assert not no_combine, f"{no_combine=} is not supported."
                 if self.block_quant:
-                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
                     assert (
                         activation == "silu"
-                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
                     return asm_moe(
                         x,
                         layer.w13_weight,
@@ -955,6 +952,7 @@ class Fp8MoEMethod:
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
+                    QuantType.per_Token,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
                     activation=(
...
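The FP8 and INT4-FP8 paths now also route through `ck_moe_2stages` (the dedicated `ck_moe_2stages_win4` entry point is dropped), with an explicit `QuantType.per_Token` argument inserted between `topk_ids` and the column-wise scales. A condensed sketch of that call, again assuming a ROCm AITER build; the variable names mirror the method above and are placeholders:

```python
from aiter import ActivationType, QuantType
from aiter.fused_moe_bf16_asm import ck_moe_2stages


def aiter_quantized_moe(
    x, w13, w2, topk_weights, topk_ids, w13_scale1, w2_scale1, activation="silu"
):
    return ck_moe_2stages(
        x,
        w13,                  # fused gate/up projection weights (FP8 or INT4-packed)
        w2,                   # down projection weights
        topk_weights,
        topk_ids,
        QuantType.per_Token,  # new explicit quantization-mode argument
        w13_scale1,           # column-wise scales prepared in process_weights_*
        w2_scale1,
        activation=(
            ActivationType.Silu if activation == "silu" else ActivationType.Gelu
        ),
    )
```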
@@ -31,7 +31,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
-if _is_hip and get_bool_env_var("CK_MOE"):
+if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
     from aiter import gemm_a8w8_blockscale
 if _is_cuda:
@@ -132,7 +132,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and get_bool_env_var("CK_MOE"):
+    elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
...
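Taken together, the commit renames three ROCm-specific flags: `CK_MOE` becomes `SGLANG_AITER_MOE`, `MOE_PADDING` becomes `SGLANG_MOE_PADDING`, and `USE_INT4_WEIGHT` becomes `SGLANG_INT4_WEIGHT`. A small helper like the following (hypothetical, not part of the commit) can keep existing launch scripts working during the transition:

```python
import os

# Old name -> new name, as introduced by this commit.
RENAMED_ENV_VARS = {
    "CK_MOE": "SGLANG_AITER_MOE",
    "MOE_PADDING": "SGLANG_MOE_PADDING",
    "USE_INT4_WEIGHT": "SGLANG_INT4_WEIGHT",
}


def migrate_env(env=os.environ):
    """Copy legacy variables to their new names when only the old ones are set."""
    for old, new in RENAMED_ENV_VARS.items():
        if old in env and new not in env:
            env[new] = env[old]


migrate_env()
```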