Unverified Commit 20bd2271 authored by fzyzcjy, committed by GitHub

Support true on-policy (#12058)

parent 64994980
@@ -29,6 +29,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
@@ -59,6 +60,11 @@ logger = logging.getLogger(__name__)
 class SiluAndMul(CustomOp):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            self._forward_method = self.forward_native
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
...
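For context, the lines above simply force the pure-PyTorch activation path when --rl-on-policy-target fsdp is set. A minimal standalone sketch of what that forward_native path computes (shapes are illustrative; no CustomOp dispatch involved):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, dtype=torch.bfloat16)  # last dim is 2 * d
d = x.shape[-1] // 2
out = F.silu(x[..., :d]) * x[..., d:]        # same math as forward_native above
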
@@ -73,9 +73,16 @@ class RMSNorm(CustomOp):
         hidden_size: int,
         eps: float = 1e-6,
         var_hidden_size: Optional[int] = None,
+        cast_x_before_out_mul: bool = False,
+        fp32_residual: bool = False,
+        weight_dtype: Optional[torch.dtype] = None,
+        override_orig_dtype: Optional[torch.dtype] = None,
     ) -> None:
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.cast_x_before_out_mul = cast_x_before_out_mul
+        self.fp32_residual = fp32_residual
+        self.override_orig_dtype = override_orig_dtype
+        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))
         self.variance_epsilon = eps
         self.hidden_size = hidden_size
         self.variance_size_override = (
@@ -165,11 +172,14 @@ class RMSNorm(CustomOp):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        orig_dtype = x.dtype
+        orig_dtype = self.override_orig_dtype or x.dtype
         x = x.to(torch.float32)
         if residual is not None:
             x = x + residual.to(torch.float32)
-            residual = x.to(orig_dtype)
+            if self.fp32_residual:
+                residual = x.clone()
+            else:
+                residual = x.to(orig_dtype)

         hidden_size = x.shape[-1]
         if hidden_size != self.hidden_size:
@@ -191,7 +201,12 @@ class RMSNorm(CustomOp):
         variance = x_var.pow(2).mean(dim=-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = (x * self.weight).to(orig_dtype)
+        if self.cast_x_before_out_mul:
+            x = self.weight * x.to(orig_dtype)
+        else:
+            x = (x * self.weight).to(orig_dtype)
+
         if residual is None:
             return x
         else:
...
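To see how the new RMSNorm options interact, here is a minimal standalone sketch (a simplification that ignores variance_size_override and CustomOp dispatch; the helper name and sample shapes are illustrative):

import torch

def rms_norm_sketch(x, weight, eps=1e-6, cast_x_before_out_mul=False,
                    override_orig_dtype=None, residual=None, fp32_residual=False):
    # Hypothetical helper mirroring the RMSNorm.forward_native changes above.
    orig_dtype = override_orig_dtype or x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        residual = x.clone() if fp32_residual else x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    if cast_x_before_out_mul:
        x = weight * x.to(orig_dtype)    # downcast first, then multiply by the weight
    else:
        x = (x * weight).to(orig_dtype)  # multiply in fp32, then downcast
    return x if residual is None else (x, residual)

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.ones(8, dtype=torch.float32)  # weight_dtype=torch.float32 in fsdp mode
y = rms_norm_sketch(x, w, cast_x_before_out_mul=True, override_orig_dtype=torch.float32)
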
@@ -593,6 +593,11 @@ class LogitsProcessor(nn.Module):
                 None,  # bias
                 True,  # is_vnni
             )
+        elif get_global_server_args().rl_on_policy_target == "fsdp":
+            # Due to tied weights, we may not be able to change lm_head's weight dtype
+            logits = torch.matmul(
+                hidden_states.bfloat16(), lm_head.weight.T.bfloat16()
+            )
         else:
             logits = torch.matmul(
                 hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
...
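The tied-weight branch above leaves the lm_head/embedding weight in its original dtype and casts both operands to bfloat16 only for the matmul. A small illustrative sketch (shapes and tensor names are placeholders, not the real module attributes):

import torch

hidden_states = torch.randn(4, 16)      # [tokens, hidden], kept in fp32 in fsdp mode
lm_head_weight = torch.randn(32, 16)    # [vocab, hidden], fp32 because it is tied to the embedding
logits = torch.matmul(hidden_states.bfloat16(), lm_head_weight.T.bfloat16())
assert logits.dtype == torch.bfloat16
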
@@ -11,6 +11,7 @@ import triton
 import triton.language as tl
 from sglang.srt.custom_op import CustomOp
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -124,18 +125,29 @@ class RotaryEmbedding(CustomOp):
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            self._forward_method = self.forward_native
+
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
         # create the cache on GPU for faster initialization. This may cause
         # a slight numerical difference between the HF implementation and ours.
+        init_device = (
+            "cpu" if get_global_server_args().rl_on_policy_target == "fsdp" else None
+        )
         inv_freq = 1.0 / (
             base
             ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+                torch.arange(
+                    0, self.rotary_dim, 2, dtype=torch.float, device=init_device
+                )
+                / self.rotary_dim
             )
         )
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            inv_freq = inv_freq.cuda()
         return inv_freq

     def _compute_cos_sin_cache(self) -> torch.Tensor:
...
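A small sketch of the inverse-frequency change above: in fsdp mode the arange is built on CPU (the HF-matching recipe mentioned in the NOTE) and the result is then moved to the GPU. The rotary_dim and base values here are illustrative:

import torch

rotary_dim, base = 128, 10000.0
inv_freq = 1.0 / (
    base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float, device="cpu") / rotary_dim)
)
if torch.cuda.is_available():  # the diff unconditionally calls .cuda() in fsdp mode
    inv_freq = inv_freq.cuda()
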
@@ -102,6 +102,14 @@ class Sampler(nn.Module):
             if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
                 probs_without_temp_scaling = torch.softmax(logits, dim=-1)
+            if get_global_server_args().rl_on_policy_target == "fsdp":
+                logits_div_temperature = (
+                    logits.bfloat16().div(sampling_info.temperatures).bfloat16()
+                )
+                logprobs_via_logsoftmax_kernel = torch.log_softmax(
+                    logits_div_temperature, dim=-1
+                )
+
             # Post process logits
             logits.div_(sampling_info.temperatures)
             logits[:] = torch.softmax(logits, dim=-1)
@@ -148,8 +156,11 @@
             )

         if return_logprob:
+            if get_global_server_args().rl_on_policy_target == "fsdp":
+                logprobs = logprobs_via_logsoftmax_kernel
+                del logprobs_via_logsoftmax_kernel
             # clamp to avoid -inf
-            if SGLANG_RETURN_ORIGINAL_LOGPROB:
+            elif SGLANG_RETURN_ORIGINAL_LOGPROB:
                 logprobs = torch.log(probs_without_temp_scaling).clamp(
                     min=torch.finfo(probs_without_temp_scaling.dtype).min
                 )
...
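A standalone sketch of the fsdp logprob path above: the temperature-scaled logits stay in bfloat16 and the log-probabilities come from a single log_softmax, instead of log(softmax(...)) on the fp32 logits. Values are illustrative:

import torch

logits = torch.randn(2, 10)                  # [batch, vocab]
temperatures = torch.full((2, 1), 0.7)
logits_div_temperature = logits.bfloat16().div(temperatures).bfloat16()
logprobs = torch.log_softmax(logits_div_temperature, dim=-1)
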
@@ -49,6 +49,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers

 Qwen2Config = None
@@ -89,6 +90,9 @@ class Qwen2MLP(nn.Module):
         self.act_fn = SiluAndMul()

     def forward(self, x):
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            x = x.bfloat16()
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -275,6 +279,11 @@ class Qwen2Model(nn.Module):
                 quant_config=quant_config,
                 enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
+                params_dtype=(
+                    torch.float32
+                    if get_global_server_args().rl_on_policy_target == "fsdp"
+                    else None
+                ),
             )
         else:
             self.embed_tokens = PPMissingLayer()
@@ -295,7 +304,19 @@ class Qwen2Model(nn.Module):
             prefix=add_prefix("layers", prefix),
         )
         if self.pp_group.is_last_rank:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            norm_kwargs = (
+                dict(
+                    weight_dtype=torch.float32,
+                    cast_x_before_out_mul=True,
+                    override_orig_dtype=torch.float32,
+                    fp32_residual=True,
+                )
+                if get_global_server_args().rl_on_policy_target == "fsdp"
+                else {}
+            )
+            self.norm = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+            )
         else:
             self.norm = PPMissingLayer(return_tuple=True)
...
@@ -29,6 +29,7 @@ from sglang.srt.model_loader.weight_utils import (
 )
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     add_prefix,
     get_cmo_stream,
@@ -88,8 +89,16 @@ class Qwen3Attention(nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.tp_rank = get_tensor_model_parallel_rank()
-        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
-        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        norm_kwargs = (
+            dict(
+                weight_dtype=torch.float32,
+                cast_x_before_out_mul=True,
+            )
+            if get_global_server_args().rl_on_policy_target == "fsdp"
+            else {}
+        )
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -158,10 +167,18 @@ class Qwen3Attention(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            hidden_states = hidden_states.bfloat16()
+
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            q = q.to(torch.bfloat16)
+            k = k.to(torch.bfloat16)
+
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -204,9 +221,22 @@ class Qwen3DecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        norm_kwargs = (
+            dict(
+                weight_dtype=torch.float32,
+                cast_x_before_out_mul=True,
+                override_orig_dtype=torch.float32,
+                fp32_residual=True,
+            )
+            if get_global_server_args().rl_on_policy_target == "fsdp"
+            else {}
+        )
+        self.input_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+        )
         self.post_attention_layernorm = RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
+            config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
         )
         self.layer_scatter_modes = LayerScatterModes.init_new(
...
@@ -472,6 +472,7 @@ class ServerArgs:
     enable_return_hidden_states: bool = False
     scheduler_recv_interval: int = 1
     numa_node: Optional[List[int]] = None
+    rl_on_policy_target: Optional[str] = None
     enable_deterministic_inference: bool = False

     # Dynamic batch tokenizer
@@ -1526,6 +1527,14 @@ class ServerArgs:
         )

     def _handle_deterministic_inference(self):
+        if self.rl_on_policy_target is not None:
+            logger.warning(
+                "Enabling deterministic inference because rl_on_policy_target is set."
+            )
+            self.enable_deterministic_inference = True
+            # TODO: remove this environment variable entirely
+            os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1"
+
         if self.enable_deterministic_inference:
             # Check sampling backend
             self.sampling_backend = "pytorch"
@@ -3300,6 +3309,13 @@ class ServerArgs:
         )

         # For deterministic inference
+        parser.add_argument(
+            "--rl-on-policy-target",
+            type=str,
+            default=ServerArgs.rl_on_policy_target,
+            choices=["fsdp"],
+            help="The training system that SGLang needs to match for true on-policy.",
+        )
         parser.add_argument(
             "--enable-deterministic-inference",
             action="store_true",
...
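A hedged usage sketch of the new flag (the dummy model path mirrors the tests below; on the command line the equivalent is --rl-on-policy-target fsdp):

from sglang.srt.server_args import ServerArgs

args = ServerArgs(model_path="dummy", rl_on_policy_target="fsdp")
# Per _handle_deterministic_inference above, argument handling is expected to
# switch on enable_deterministic_inference and export
# SGLANG_ENABLE_DETERMINISTIC_INFERENCE=1.
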
@@ -6,6 +6,7 @@ import torch
 import torch.nn.functional as F
 from utils import GeluAndMul, SiluAndMul, precision

+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.test.test_utils import CustomTestCase

 torch.manual_seed(1234)
@@ -17,6 +18,8 @@ class TestActivation(CustomTestCase):
     dtype = [torch.float16, torch.bfloat16]

     def _silu_and_mul_test(self, m, n, dtype):
+        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
         x = torch.randn([m, n], dtype=dtype)
         out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
...
@@ -20,6 +20,7 @@ from utils import (
     torch_w8a8_per_column_moe,
 )

+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.test.test_utils import CustomTestCase

 torch.manual_seed(1234)
@@ -149,6 +150,8 @@ class TestSharedExpert(CustomTestCase):
             self._int8_shared_expert(*params)

     def _fp8_shared_expert(self, M, N, K, routed_scaling_factor):
+        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
         dtype = torch.bfloat16
         prepack = True
...
@@ -6,6 +6,7 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.test.test_utils import CustomTestCase
@@ -96,6 +97,8 @@ def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.f
 def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
     """This function performs fused moe with block-wise quantization using native torch."""
+    set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
...
@@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.test.test_utils import CustomTestCase
@@ -35,6 +36,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
 def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
     """This function performs fused moe with per-column int8 quantization using native torch."""
+    set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
     B, D = a.shape
     # Perform per-token quantization
     a_q, a_s = per_token_quant_int8(a)
...
@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.srt.utils import is_hip
 from sglang.test.test_utils import CustomTestCase
@@ -63,6 +64,8 @@ class TestFusedMOE(CustomTestCase):
         a1_scale=None,
         a2_scale=None,
     ):
+        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
         out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
...
@@ -9,6 +9,7 @@ from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsQuantInfo
 from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
 from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.test.test_utils import CustomTestCase
@@ -56,6 +57,8 @@ class TestFusedMOE(CustomTestCase):
         topk,
         return_per_expert: bool = False,
     ):
+        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
         out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
...
@@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.test.test_utils import CustomTestCase
@@ -40,6 +41,8 @@ def fp8_mask(a, mask):
 def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
     """This function performs fused moe with per-column int8 quantization using native torch."""
+    set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
     B, D = a.shape
     # Perform per-token quantization
     a_q, a_s = scaled_fp8_quant(a, use_per_token_if_dynamic=True)
...
@@ -6,6 +6,7 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler

 NUM_EXPERTS = [8, 64]
 TOP_KS = [2, 6]
@@ -116,6 +117,8 @@ def quantize_weights(
 def torch_moe(a, w1, w2, score, topk):
+    set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
...